From ec8247ec5979ff5f8f31f88a797a3e8dfa48b915 Mon Sep 17 00:00:00 2001
From: Richy Wang
Date: Mon, 26 Jun 2023 20:35:25 +0800
Subject: [PATCH] Fixed bug in AnalyticDB Vector Store caused by upgraded
 SQLAlchemy version (#6736)

---
 langchain/vectorstores/analyticdb.py | 106 +++++++++++++--------------
 1 file changed, 52 insertions(+), 54 deletions(-)

diff --git a/langchain/vectorstores/analyticdb.py b/langchain/vectorstores/analyticdb.py
index 5d422a3beb1e8c..385d666f1ec203 100644
--- a/langchain/vectorstores/analyticdb.py
+++ b/langchain/vectorstores/analyticdb.py
@@ -80,34 +80,34 @@ def create_table_if_not_exists(self) -> None:
             extend_existing=True,
         )
         with self.engine.connect() as conn:
-            # Create the table
-            Base.metadata.create_all(conn)
-
-            # Check if the index exists
-            index_name = f"{self.collection_name}_embedding_idx"
-            index_query = text(
-                f"""
-                SELECT 1
-                FROM pg_indexes
-                WHERE indexname = '{index_name}';
-                """
-            )
-            result = conn.execute(index_query).scalar()
+            with conn.begin():
+                # Create the table
+                Base.metadata.create_all(conn)
 
-            # Create the index if it doesn't exist
-            if not result:
-                index_statement = text(
+                # Check if the index exists
+                index_name = f"{self.collection_name}_embedding_idx"
+                index_query = text(
                     f"""
-                    CREATE INDEX {index_name}
-                    ON {self.collection_name} USING ann(embedding)
-                    WITH (
-                        "dim" = {self.embedding_dimension},
-                        "hnsw_m" = 100
-                    );
+                    SELECT 1
+                    FROM pg_indexes
+                    WHERE indexname = '{index_name}';
                     """
                 )
-                conn.execute(index_statement)
-                conn.commit()
+                result = conn.execute(index_query).scalar()
+
+                # Create the index if it doesn't exist
+                if not result:
+                    index_statement = text(
+                        f"""
+                        CREATE INDEX {index_name}
+                        ON {self.collection_name} USING ann(embedding)
+                        WITH (
+                            "dim" = {self.embedding_dimension},
+                            "hnsw_m" = 100
+                        );
+                        """
+                    )
+                    conn.execute(index_statement)
 
     def create_collection(self) -> None:
         if self.pre_delete_collection:
@@ -118,8 +118,8 @@ def delete_collection(self) -> None:
         self.logger.debug("Trying to delete collection")
         drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
         with self.engine.connect() as conn:
-            conn.execute(drop_statement)
-            conn.commit()
+            with conn.begin():
+                conn.execute(drop_statement)
 
     def add_texts(
         self,
@@ -160,30 +160,28 @@ def add_texts(
             chunks_table_data = []
 
         with self.engine.connect() as conn:
-            for document, metadata, chunk_id, embedding in zip(
-                texts, metadatas, ids, embeddings
-            ):
-                chunks_table_data.append(
-                    {
-                        "id": chunk_id,
-                        "embedding": embedding,
-                        "document": document,
-                        "metadata": metadata,
-                    }
-                )
-
-                # Execute the batch insert when the batch size is reached
-                if len(chunks_table_data) == batch_size:
+            with conn.begin():
+                for document, metadata, chunk_id, embedding in zip(
+                    texts, metadatas, ids, embeddings
+                ):
+                    chunks_table_data.append(
+                        {
+                            "id": chunk_id,
+                            "embedding": embedding,
+                            "document": document,
+                            "metadata": metadata,
+                        }
+                    )
+
+                    # Execute the batch insert when the batch size is reached
+                    if len(chunks_table_data) == batch_size:
+                        conn.execute(insert(chunks_table).values(chunks_table_data))
+                        # Clear the chunks_table_data list for the next batch
+                        chunks_table_data.clear()
+
+                # Insert any remaining records that didn't make up a full batch
+                if chunks_table_data:
                     conn.execute(insert(chunks_table).values(chunks_table_data))
-                    # Clear the chunks_table_data list for the next batch
-                    chunks_table_data.clear()
-
-            # Insert any remaining records that didn't make up a full batch
-            if chunks_table_data:
-                conn.execute(insert(chunks_table).values(chunks_table_data))
-
-            # Commit the transaction only once after all records have been inserted
-            conn.commit()
 
         return ids
 
@@ -333,9 +331,9 @@ def from_texts(
     ) -> AnalyticDB:
         """
         Return VectorStore initialized from texts and embeddings.
-        Postgres connection string is required
+        Postgres Connection string is required
         Either pass it as a parameter
-        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        or set the PG_CONNECTION_STRING environment variable.
         """
 
         connection_string = cls.get_connection_string(kwargs)
@@ -363,7 +361,7 @@ def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
             raise ValueError(
                 "Postgres connection string is required"
                 "Either pass it as a parameter"
-                "or set the PGVECTOR_CONNECTION_STRING environment variable."
+                "or set the PG_CONNECTION_STRING environment variable."
            )
 
         return connection_string
@@ -381,9 +379,9 @@ def from_documents(
     ) -> AnalyticDB:
         """
         Return VectorStore initialized from documents and embeddings.
-        Postgres connection string is required
+        Postgres Connection string is required
         Either pass it as a parameter
-        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        or set the PG_CONNECTION_STRING environment variable.
         """
 
         texts = [d.page_content for d in documents]
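
A note on the pattern this patch adopts (not part of the patch itself):
the old code issued per-statement conn.commit() calls, but
Connection.commit() exists only on SQLAlchemy's 2.0-style connections, so
those calls break depending on the installed SQLAlchemy version.
Connection.begin() is available across 1.x and 2.x: the transaction
commits when the with-block exits cleanly and rolls back if an exception
escapes, which is why each unit of work above is wrapped in
"with conn.begin():". A minimal, self-contained sketch of the pattern,
assuming a throwaway SQLite database (the URL and the demo table are
illustrative placeholders, not taken from the patch):

    from sqlalchemy import create_engine, text

    # Placeholder engine; the AnalyticDB store connects to Postgres instead.
    engine = create_engine("sqlite:///:memory:")

    with engine.connect() as conn:
        # begin() opens an explicit transaction: it commits when the block
        # exits cleanly and rolls back if an exception is raised inside it,
        # so no trailing conn.commit() is needed.
        with conn.begin():
            conn.execute(text("CREATE TABLE demo (id INTEGER, doc TEXT)"))
            conn.execute(
                text("INSERT INTO demo (id, doc) VALUES (:id, :doc)"),
                [{"id": 1, "doc": "foo"}, {"id": 2, "doc": "bar"}],
            )

engine.begin() would collapse the two with-blocks into one; keeping
engine.connect() and nesting conn.begin(), as the patch does, leaves the
surrounding connection-handling code untouched.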