Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement vector length definition at init time in PGVector for indexing #16133

Merged
merged 2 commits into from
Jan 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions libs/community/langchain_community/vectorstores/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class BaseModel(Base):
_classes: Any = None


def _get_embedding_collection_store() -> Any:
def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
global _classes
if _classes is not None:
return _classes
Expand Down Expand Up @@ -125,7 +125,7 @@ class EmbeddingStore(BaseModel):
)
collection = relationship(CollectionStore, back_populates="embeddings")

embedding: Vector = sqlalchemy.Column(Vector(None))
embedding: Vector = sqlalchemy.Column(Vector(vector_dimension))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True)

Expand All @@ -151,6 +151,10 @@ class PGVector(VectorStore):
connection_string: Postgres connection string.
embedding_function: Any embedding function implementing
`langchain.embeddings.base.Embeddings` interface.
embedding_length: The length of the embedding vector. (default: None)
NOTE: This is not mandatory. Defining it will prevent vectors of
any other size from being added to the embeddings table but, without it,
the embeddings can't be indexed.
collection_name: The name of the collection to use. (default: langchain)
NOTE: This is not the name of the table, but the name of the collection.
The tables will be created when initializing the store (if they do not already exist)
Expand Down Expand Up @@ -183,6 +187,7 @@ def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
embedding_length: Optional[int] = None,
Frank995 marked this conversation as resolved.
Show resolved Hide resolved
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
Expand All @@ -195,6 +200,7 @@ def __init__(
) -> None:
self.connection_string = connection_string
self.embedding_function = embedding_function
self._embedding_length = embedding_length
self.collection_name = collection_name
self.collection_metadata = collection_metadata
self._distance_strategy = distance_strategy
Expand All @@ -211,7 +217,9 @@ def __post_init__(
"""Initialize the store."""
self.create_vector_extension()

EmbeddingStore, CollectionStore = _get_embedding_collection_store()
EmbeddingStore, CollectionStore = _get_embedding_collection_store(
self._embedding_length
)
self.CollectionStore = CollectionStore
self.EmbeddingStore = EmbeddingStore
self.create_tables_if_not_exists()
Expand Down
Loading