Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: DuckDB VS - expose similarity, improve performance of from_texts #20971

Merged
Merged 11 commits on May 24, 2024
8 changes: 4 additions & 4 deletions docs/docs/integrations/vectorstores/duckdb.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install duckdb langchain-community"
"! pip install duckdb langchain langchain-community langchain-openai"
]
},
{
Expand Down Expand Up @@ -86,7 +86,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -100,9 +100,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
74 changes: 56 additions & 18 deletions libs/community/langchain_community/vectorstores/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@
from __future__ import annotations

import json
import logging
import uuid
import warnings
from typing import Any, Iterable, List, Optional, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

# Module-level logger. add_texts() logs through it at INFO level when pandas
# cannot be imported and rows are inserted one at a time instead.
logger = logging.getLogger(__name__)

# Shared defaults for the `vector_key`, `id_key`, `text_key` and `table_name`
# arguments. Both DuckDB.__init__() and DuckDB.from_texts() read them, so the
# two entry points fall back to the same names.
DEFAULT_VECTOR_KEY = "embedding"
DEFAULT_ID_KEY = "id"
DEFAULT_TEXT_KEY = "text"
DEFAULT_TABLE_NAME = "embeddings"
# Alias given to the cosine-similarity expression in similarity_search().
# Results are ordered by this alias (descending), and the value is copied into
# each returned Document's metadata under the key "_" + SIMILARITY_ALIAS.
SIMILARITY_ALIAS = "similarity_score"


class DuckDB(VectorStore):
"""`DuckDB` vector store.
Expand Down Expand Up @@ -76,10 +86,10 @@ def __init__(
*,
connection: Optional[Any] = None,
embedding: Embeddings,
vector_key: str = "embedding",
id_key: str = "id",
text_key: str = "text",
table_name: str = "vectorstore",
vector_key: str = DEFAULT_VECTOR_KEY,
id_key: str = DEFAULT_ID_KEY,
text_key: str = DEFAULT_TEXT_KEY,
table_name: str = DEFAULT_TABLE_NAME,
jaceksan marked this conversation as resolved.
Show resolved Hide resolved
):
"""Initialize with DuckDB connection and setup for vector storage."""
try:
Expand All @@ -100,8 +110,6 @@ def __init__(
raise ValueError("An embedding function or model must be provided.")

if connection is None:
import warnings

warnings.warn(
"No DuckDB connection provided. A new connection will be created."
"This connection is running in memory and no data will be persisted."
Expand Down Expand Up @@ -138,13 +146,25 @@ def add_texts(
Returns:
List of ids of the added texts.
"""
have_pandas = False
try:
import pandas as pd

have_pandas = True
except ImportError:
logger.info(
"Unable to import pandas. "
"Install it with `pip install -U pandas` "
"to improve performance of add_texts()."
)

# Extract ids from kwargs or generate new ones if not provided
ids = kwargs.pop("ids", [str(uuid.uuid4()) for _ in texts])

# Embed texts and create documents
ids = ids or [str(uuid.uuid4()) for _ in texts]
embeddings = self._embedding.embed_documents(list(texts))
data = []
for idx, text in enumerate(texts):
embedding = embeddings[idx]
# Serialize metadata if present, else default to None
Expand All @@ -153,9 +173,26 @@ def add_texts(
if metadatas and idx < len(metadatas)
else None
)
if have_pandas:
data.append(
{
self._id_key: ids[idx],
self._text_key: text,
self._vector_key: embedding,
"metadata": metadata,
}
)
else:
self._connection.execute(
f"INSERT INTO {self._table_name} VALUES (?,?,?,?)",
[ids[idx], text, embedding, metadata],
)

if have_pandas:
# noinspection PyUnusedLocal
df = pd.DataFrame.from_dict(data) # noqa: F841
self._connection.execute(
f"INSERT INTO {self._table_name} VALUES (?,?,?,?)",
[ids[idx], text, embedding, metadata],
f"INSERT INTO {self._table_name} SELECT * FROM df",
)
return ids

Expand All @@ -181,20 +218,21 @@ def similarity_search(
self._table.select(
*[
self.duckdb.StarExpression(exclude=[]),
list_cosine_similarity.alias("similarity"),
list_cosine_similarity.alias(SIMILARITY_ALIAS),
]
)
.order("similarity desc")
.order(f"{SIMILARITY_ALIAS} desc")
.limit(k)
.select(
self.duckdb.StarExpression(exclude=["similarity", self._vector_key])
)
.fetchdf()
)
return [
Document(
page_content=docs[self._text_key][idx],
metadata=json.loads(docs["metadata"][idx])
metadata={
**json.loads(docs["metadata"][idx]),
# using underscore prefix to avoid conflicts with user metadata keys
f"_{SIMILARITY_ALIAS}": docs[SIMILARITY_ALIAS][idx],
}
if docs["metadata"][idx]
else {},
)
Expand Down Expand Up @@ -231,10 +269,10 @@ def from_texts(

# Extract kwargs for DuckDB instance creation
connection = kwargs.get("connection", None)
vector_key = kwargs.get("vector_key", "vector")
id_key = kwargs.get("id_key", "id")
text_key = kwargs.get("text_key", "text")
table_name = kwargs.get("table_name", "embeddings")
vector_key = kwargs.get("vector_key", DEFAULT_VECTOR_KEY)
id_key = kwargs.get("id_key", DEFAULT_ID_KEY)
text_key = kwargs.get("text_key", DEFAULT_TEXT_KEY)
table_name = kwargs.get("table_name", DEFAULT_TABLE_NAME)

# Create an instance of DuckDB
instance = DuckDB(
Expand Down