Skip to content

Commit

Permalink
Fixed bug in AnalyticDB Vector Store caused by upgrading the SQLAlchemy version (#6736)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxuqi committed Jun 26, 2023
1 parent d84a3bc commit ec8247e
Showing 1 changed file with 52 additions and 54 deletions.
106 changes: 52 additions & 54 deletions langchain/vectorstores/analyticdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,34 +80,34 @@ def create_table_if_not_exists(self) -> None:
extend_existing=True,
)
with self.engine.connect() as conn:
# Create the table
Base.metadata.create_all(conn)

# Check if the index exists
index_name = f"{self.collection_name}_embedding_idx"
index_query = text(
f"""
SELECT 1
FROM pg_indexes
WHERE indexname = '{index_name}';
"""
)
result = conn.execute(index_query).scalar()
with conn.begin():
# Create the table
Base.metadata.create_all(conn)

# Create the index if it doesn't exist
if not result:
index_statement = text(
# Check if the index exists
index_name = f"{self.collection_name}_embedding_idx"
index_query = text(
f"""
CREATE INDEX {index_name}
ON {self.collection_name} USING ann(embedding)
WITH (
"dim" = {self.embedding_dimension},
"hnsw_m" = 100
);
SELECT 1
FROM pg_indexes
WHERE indexname = '{index_name}';
"""
)
conn.execute(index_statement)
conn.commit()
result = conn.execute(index_query).scalar()

# Create the index if it doesn't exist
if not result:
index_statement = text(
f"""
CREATE INDEX {index_name}
ON {self.collection_name} USING ann(embedding)
WITH (
"dim" = {self.embedding_dimension},
"hnsw_m" = 100
);
"""
)
conn.execute(index_statement)

def create_collection(self) -> None:
if self.pre_delete_collection:
Expand All @@ -118,8 +118,8 @@ def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection")
drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
with self.engine.connect() as conn:
conn.execute(drop_statement)
conn.commit()
with conn.begin():
conn.execute(drop_statement)

def add_texts(
self,
Expand Down Expand Up @@ -160,30 +160,28 @@ def add_texts(

chunks_table_data = []
with self.engine.connect() as conn:
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)

# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == batch_size:
with conn.begin():
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)

# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == batch_size:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()

# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()

# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))

# Commit the transaction only once after all records have been inserted
conn.commit()

return ids

Expand Down Expand Up @@ -333,9 +331,9 @@ def from_texts(
) -> AnalyticDB:
"""
Return VectorStore initialized from texts and embeddings.
Postgres connection string is required
Postgres Connection string is required
Either pass it as a parameter
or set the PGVECTOR_CONNECTION_STRING environment variable.
or set the PG_CONNECTION_STRING environment variable.
"""

connection_string = cls.get_connection_string(kwargs)
Expand Down Expand Up @@ -363,7 +361,7 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
raise ValueError(
"Postgres connection string is required"
"Either pass it as a parameter"
"or set the PGVECTOR_CONNECTION_STRING environment variable."
"or set the PG_CONNECTION_STRING environment variable."
)

return connection_string
Expand All @@ -381,9 +379,9 @@ def from_documents(
) -> AnalyticDB:
"""
Return VectorStore initialized from documents and embeddings.
Postgres connection string is required
Postgres Connection string is required
Either pass it as a parameter
or set the PGVECTOR_CONNECTION_STRING environment variable.
or set the PG_CONNECTION_STRING environment variable.
"""

texts = [d.page_content for d in documents]
Expand Down

0 comments on commit ec8247e

Please sign in to comment.