feat: support saving with customized content column and saving/loading with non-default metadata JSON column. (#19)
loeng2023 committed Feb 15, 2024
1 parent 5aecbd0 commit b489a43
Showing 4 changed files with 135 additions and 40 deletions.
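
In short, the commit replaces the boolean `store_metadata` flag with two explicit parameters: `content_column` (the column holding document text, previously fixed to `page_content`) and `metadata_json_column` (the JSON column for leftover metadata, previously fixed to `langchain_metadata`, now skippable by passing `None`). A minimal sketch of the new surface, assuming an already-initialized `MySQLEngine` named `engine`; the table and column names here are illustrative, not part of the diff:

import sqlalchemy

# `engine` is an existing MySQLEngine; its construction is not shown in this commit.
engine.init_document_table(
    "fruits",
    metadata_columns=[sqlalchemy.Column("fruit_name", sqlalchemy.String(100), nullable=True)],
    content_column="description",            # previously always "page_content"
    metadata_json_column="extra_metadata",   # previously always "langchain_metadata"
    overwrite_existing=True,
)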
21 changes: 15 additions & 6 deletions src/langchain_google_cloud_sql_mysql/mysql_engine.py
@@ -233,7 +233,9 @@ def init_document_table(
self,
table_name: str,
metadata_columns: List[sqlalchemy.Column] = [],
store_metadata: bool = True,
content_column: str = "page_content",
metadata_json_column: Optional[str] = "langchain_metadata",
overwrite_existing: bool = False,
) -> None:
"""
Create a table for saving langchain documents.
@@ -242,22 +244,29 @@
table_name (str): The MySQL database table name.
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
to create for custom metadata. Optional.
store_metadata (bool): Whether to store extra metadata in a metadata column
if not described in 'metadata' field list (Default: True).
content_column (str): The column to store document content.
Default: `page_content`.
metadata_json_column (Optional[str]): The column to store extra metadata in JSON format.
Default: `langchain_metadata`. Optional.
overwrite_existing (bool): Whether to drop existing table. Default: False.
"""
if overwrite_existing:
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`;"))

columns = [
sqlalchemy.Column(
"page_content",
content_column,
sqlalchemy.UnicodeText,
primary_key=False,
nullable=False,
)
]
columns += metadata_columns
if store_metadata:
if metadata_json_column:
columns.append(
sqlalchemy.Column(
"langchain_metadata",
metadata_json_column,
sqlalchemy.JSON,
primary_key=False,
nullable=True,
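
Callers that previously passed `store_metadata=False` get the same behavior from `metadata_json_column=None`, as the updated integration tests below show. A sketch of the migration (table name illustrative):

# Before (removed API): engine.init_document_table(table_name, store_metadata=False)
# After: passing None suppresses creation of the JSON metadata column entirely.
engine.init_document_table(
    "fruits",
    metadata_json_column=None,
)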
94 changes: 79 additions & 15 deletions src/langchain_google_cloud_sql_mysql/mysql_loader.py
@@ -27,33 +27,41 @@


def _parse_doc_from_row(
content_columns: Iterable[str], metadata_columns: Iterable[str], row: Dict
content_columns: Iterable[str],
metadata_columns: Iterable[str],
row: Dict,
metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Document:
page_content = " ".join(
str(row[column]) for column in content_columns if column in row
)
metadata: Dict[str, Any] = {}
# unnest metadata from langchain_metadata column
if DEFAULT_METADATA_COL in metadata_columns and row.get(DEFAULT_METADATA_COL):
for k, v in row[DEFAULT_METADATA_COL].items():
if row.get(metadata_json_column):
for k, v in row[metadata_json_column].items():
metadata[k] = v
# load metadata from other columns
for column in metadata_columns:
if column in row and column != DEFAULT_METADATA_COL:
if column in row and column != metadata_json_column:
metadata[column] = row[column]
return Document(page_content=page_content, metadata=metadata)


def _parse_row_from_doc(column_names: Iterable[str], doc: Document) -> Dict:
def _parse_row_from_doc(
column_names: Iterable[str],
doc: Document,
content_column: str = DEFAULT_CONTENT_COL,
metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Dict:
doc_metadata = doc.metadata.copy()
row: Dict[str, Any] = {DEFAULT_CONTENT_COL: doc.page_content}
row: Dict[str, Any] = {content_column: doc.page_content}
for entry in doc.metadata:
if entry in column_names:
row[entry] = doc_metadata[entry]
del doc_metadata[entry]
# store extra metadata in langchain_metadata column in json format
if DEFAULT_METADATA_COL in column_names and len(doc_metadata) > 0:
row[DEFAULT_METADATA_COL] = doc_metadata
if metadata_json_column in column_names and len(doc_metadata) > 0:
row[metadata_json_column] = doc_metadata
return row
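
To make the two helpers concrete, here is a sketch of a round trip through them; the row values and column names are invented for illustration, and the expected results follow from the logic above:

# _parse_doc_from_row joins the content columns and unnests the JSON column.
doc = _parse_doc_from_row(
    content_columns=["variety", "price_per_unit"],
    metadata_columns=["fruit_id"],
    row={
        "variety": "Granny Smith",
        "price_per_unit": 0.99,
        "fruit_id": 1,
        "extra_metadata": {"organic": 1},
    },
    metadata_json_column="extra_metadata",
)
# doc == Document(page_content="Granny Smith 0.99",
#                 metadata={"organic": 1, "fruit_id": 1})

# _parse_row_from_doc is the reverse: metadata keys matching column names
# become row fields, and whatever is left lands in the JSON column.
row = _parse_row_from_doc(
    column_names=["description", "fruit_id", "extra_metadata"],
    doc=doc,
    content_column="description",
    metadata_json_column="extra_metadata",
)
# row == {"description": "Granny Smith 0.99", "fruit_id": 1,
#         "extra_metadata": {"organic": 1}}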


@@ -67,6 +75,7 @@ def __init__(
query: str = "",
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
metadata_json_column: Optional[str] = None,
):
"""
Document page content defaults to the first column present in the query or table and
@@ -85,12 +94,15 @@
of the document. Optional.
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
Optional.
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
dictionary. Default: `langchain_metadata`. Optional.
"""
self.engine = engine
self.table_name = table_name
self.query = query
self.content_columns = content_columns
self.metadata_columns = metadata_columns
self.metadata_json_column = metadata_json_column
if not self.table_name and not self.query:
raise ValueError("One of 'table_name' or 'query' must be specified.")
if self.table_name and self.query:
@@ -139,6 +151,25 @@ def lazy_load(self) -> Iterator[Document]:
metadata_columns = self.metadata_columns or [
col for col in column_names if col not in content_columns
]
# check validity of metadata json column
if (
self.metadata_json_column
and self.metadata_json_column not in column_names
):
raise ValueError(
f"Column {self.metadata_json_column} not found in query result {column_names}."
)
# check validity of other columns
all_names = content_columns + metadata_columns
for name in all_names:
if name not in column_names:
raise ValueError(
f"Column {name} not found in query result {column_names}."
)
# use default metadata json column if not specified
metadata_json_column = self.metadata_json_column or DEFAULT_METADATA_COL

# load document one by one
while True:
row = result_proxy.fetchone()
if not row:
@@ -151,7 +182,12 @@
row_data[column] = json.loads(value)
else:
row_data[column] = value
yield _parse_doc_from_row(content_columns, metadata_columns, row_data)
yield _parse_doc_from_row(
content_columns,
metadata_columns,
row_data,
metadata_json_column,
)
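
Combined with the new constructor argument, loading might look like this (a sketch; `engine` and all table/column names are assumptions, not part of the diff):

loader = MySQLLoader(
    engine=engine,
    table_name="fruits",
    content_columns=["description"],
    metadata_columns=["fruit_id", "fruit_name"],
    metadata_json_column="extra_metadata",
)
# lazy_load() streams one Document per row; load() collects them into a list.
for doc in loader.lazy_load():
    print(doc.page_content, doc.metadata)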


class MySQLDocumentSaver:
@@ -161,6 +197,8 @@ def __init__(
self,
engine: MySQLEngine,
table_name: str,
content_column: Optional[str] = None,
metadata_json_column: Optional[str] = None,
):
"""
MySQLDocumentSaver allows for saving langchain documents in a database. If the table
@@ -169,17 +207,33 @@
- langchain_metadata (type: JSON)
Args:
engine: MySQLEngine object to connect to the MySQL database.
table_name: The name of table for saving documents.
engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
table_name (str): The name of table for saving documents.
content_column (str): The column to store document content.
Default: `page_content`. Optional.
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
dictionary. Default: `langchain_metadata`. Optional.
"""
self.engine = engine
self.table_name = table_name
self._table = self.engine._load_document_table(table_name)
if DEFAULT_CONTENT_COL not in self._table.columns.keys():

self.content_column = content_column or DEFAULT_CONTENT_COL
if self.content_column not in self._table.columns.keys():
raise ValueError(
f"Missing '{DEFAULT_CONTENT_COL}' field in table {table_name}."
f"Missing '{self.content_column}' field in table {table_name}."
)

# check metadata_json_column existence if it's provided.
if (
metadata_json_column
and metadata_json_column not in self._table.columns.keys()
):
raise ValueError(
f"Cannot find '{metadata_json_column}' column in table {table_name}."
)
self.metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL

def add_documents(self, docs: List[Document]) -> None:
"""
Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
@@ -190,7 +244,12 @@ def add_documents(self, docs: List[Document]) -> None:
"""
with self.engine.connect() as conn:
for doc in docs:
row = _parse_row_from_doc(self._table.columns.keys(), doc)
row = _parse_row_from_doc(
self._table.columns.keys(),
doc,
self.content_column,
self.metadata_json_column,
)
conn.execute(sqlalchemy.insert(self._table).values(row))
conn.commit()

@@ -204,7 +263,12 @@ def delete(self, docs: List[Document]) -> None:
"""
with self.engine.connect() as conn:
for doc in docs:
row = _parse_row_from_doc(self._table.columns.keys(), doc)
row = _parse_row_from_doc(
self._table.columns.keys(),
doc,
self.content_column,
self.metadata_json_column,
)
# delete by matching all fields of document
where_conditions = []
for col in self._table.columns:
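
A matching sketch for the saver side (names again illustrative; `Document` is assumed to come from `langchain_core.documents`):

from langchain_core.documents import Document

saver = MySQLDocumentSaver(
    engine=engine,
    table_name="fruits",
    content_column="description",
    metadata_json_column="extra_metadata",
)
docs = [Document(page_content="Granny Smith 150 0.99", metadata={"fruit_id": 1})]
saver.add_documents(docs)

# delete() matches on every field parsed from the document, so the same
# Document objects identify exactly the rows they created.
saver.delete(docs)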
54 changes: 37 additions & 17 deletions tests/integration/test_mysql_loader.py
@@ -249,7 +249,6 @@ def test_load_from_query_with_langchain_metadata(engine):
query=query,
metadata_columns=[
"fruit_name",
"langchain_metadata",
],
)

@@ -294,8 +293,9 @@ def test_save_doc_with_default_metadata(engine):
]


@pytest.mark.parametrize("store_metadata", [True, False])
def test_save_doc_with_customized_metadata(engine, store_metadata):
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
def test_save_doc_with_customized_metadata(engine, metadata_json_column):
content_column = "content_col_test"
engine.init_document_table(
table_name,
metadata_columns=[
@@ -312,35 +312,43 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
nullable=True,
),
],
store_metadata=store_metadata,
content_column=content_column,
metadata_json_column=metadata_json_column,
overwrite_existing=True,
)
test_docs = [
Document(
page_content="Granny Smith 150 0.99",
metadata={"fruit_id": 1, "fruit_name": "Apple", "organic": 1},
),
]
saver = MySQLDocumentSaver(engine=engine, table_name=table_name)
saver = MySQLDocumentSaver(
engine=engine,
table_name=table_name,
content_column=content_column,
metadata_json_column=metadata_json_column,
)
loader = MySQLLoader(
engine=engine,
table_name=table_name,
content_columns=[content_column],
metadata_columns=[
"fruit_id",
"fruit_name",
"organic",
],
metadata_json_column=metadata_json_column,
)

saver.add_documents(test_docs)
docs = loader.load()

if store_metadata:
if metadata_json_column:
assert docs == test_docs
assert engine._load_document_table(table_name).columns.keys() == [
"page_content",
content_column,
"fruit_name",
"organic",
"langchain_metadata",
metadata_json_column,
]
else:
assert docs == [
@@ -350,7 +358,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
),
]
assert engine._load_document_table(table_name).columns.keys() == [
"page_content",
content_column,
"fruit_name",
"organic",
]
@@ -359,7 +367,7 @@ def test_save_doc_with_customized_metadata(engine, store_metadata):
def test_save_doc_without_metadata(engine):
engine.init_document_table(
table_name,
store_metadata=False,
metadata_json_column=None,
)
test_docs = [
Document(
@@ -413,8 +421,9 @@ def test_delete_doc_with_default_metadata(engine):
assert len(loader.load()) == 0


@pytest.mark.parametrize("store_metadata", [True, False])
def test_delete_doc_with_customized_metadata(engine, store_metadata):
@pytest.mark.parametrize("metadata_json_column", [None, "metadata_col_test"])
def test_delete_doc_with_customized_metadata(engine, metadata_json_column):
content_column = "content_col_test"
engine.init_document_table(
table_name,
metadata_columns=[
@@ -431,7 +440,9 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
nullable=True,
),
],
store_metadata=store_metadata,
content_column=content_column,
metadata_json_column=metadata_json_column,
overwrite_existing=True,
)
test_docs = [
Document(
@@ -443,8 +454,18 @@ def test_delete_doc_with_customized_metadata(engine, store_metadata):
metadata={"fruit_id": 2, "fruit_name": "Banana", "organic": 1},
),
]
saver = MySQLDocumentSaver(engine=engine, table_name=table_name)
loader = MySQLLoader(engine=engine, table_name=table_name)
saver = MySQLDocumentSaver(
engine=engine,
table_name=table_name,
content_column=content_column,
metadata_json_column=metadata_json_column,
)
loader = MySQLLoader(
engine=engine,
table_name=table_name,
content_columns=[content_column],
metadata_json_column=metadata_json_column,
)

saver.add_documents(test_docs)
docs = loader.load()
@@ -474,7 +495,6 @@ def test_delete_doc_with_query(engine):
nullable=True,
),
],
store_metadata=True,
)
test_docs = [
Document(
6 changes: 4 additions & 2 deletions tests/unit/test_doc2row.py
@@ -89,11 +89,13 @@ def test_row2doc_ovrride_default_metadata():


def test_row2doc_metadata_col_nonexist():
assert _parse_doc_from_row(
doc = _parse_doc_from_row(
["variety", "quantity_in_stock", "price_per_unit"],
["fruit-id"],
row_customized_nested,
) == Document(page_content="Granny Smith 150 0.99")
metadata_json_column="non-exist",
)
assert doc == Document(page_content="Granny Smith 150 0.99")


def test_doc2row_default():
