# Databricks LangChain Embeddings Demo

This notebook shows how to attach vector embeddings to a Spark DataFrame using `spark_fuse.utils.llm.with_langchain_embeddings` together with LangChain's `DatabricksEmbeddings`.

## Prerequisites
- Packages: `spark-fuse`, `pyspark`, `databricks-langchain` (provides the `databricks_langchain` import used below), `langchain-core`, `langchain-community`, `langchain-text-splitters`.
- Databricks credentials: set `DATABRICKS_HOST` and `DATABRICKS_TOKEN` (or configure a Databricks profile).
- The name of a Model Serving embedding endpoint, such as `databricks-bge-large-en`, or any other embedding endpoint your workspace exposes.

To avoid external calls while prototyping, switch to the stub embeddings class at the bottom of the notebook.
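
The credential prerequisite above can be checked up front, before any cell calls the endpoint. The following is a minimal sketch, assuming environment-variable authentication; `missing_databricks_env` is a hypothetical helper name and is not part of `spark-fuse` or LangChain.

```python
import os

REQUIRED_VARS = ("DATABRICKS_HOST", "DATABRICKS_TOKEN")


def missing_databricks_env(env=None):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_databricks_env()
if missing:
    print("Missing Databricks settings:", ", ".join(missing))
else:
    print("Databricks credentials found in the environment.")
```

If you authenticate through a Databricks profile instead, this check reports the variables as missing even though authentication may still succeed.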


In [None]:
from databricks_langchain import DatabricksEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from spark_fuse.spark import create_session
from spark_fuse.utils.llm import with_langchain_embeddings

spark = create_session(app_name="spark-fuse-dbx-embeddings-demo", master="local[2]")


## Sample data

We will embed a small set of product descriptions.

In [None]:
data = [
    {"id": 1, "text": "A lightweight hiking backpack with 20L capacity."},
    {"id": 2, "text": "Insulated stainless steel water bottle, 750ml."},
    {"id": 3, "text": "Breathable running shoes for road and trail."},
]
df = spark.createDataFrame(data)
df.show(truncate=False)

## Embed with DatabricksEmbeddings

Use a factory (`lambda: DatabricksEmbeddings(...)`) so each executor initializes its own client. Make sure `DATABRICKS_HOST` and `DATABRICKS_TOKEN` are set, and pass the name of your embedding endpoint as `endpoint`.


In [None]:
embedded = with_langchain_embeddings(
    df,
    input_col="text",
    embeddings=lambda: DatabricksEmbeddings(endpoint="databricks-bge-large-en"),
    output_col="embedding",
    batch_size=8,
)

embedded.select("id", "embedding").show(3, truncate=False)
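
To illustrate why a factory and a `batch_size` are useful, here is a simplified, Spark-free sketch of how one partition of rows might be embedded: the factory is called once for the partition, and texts are sent to the client in batches. This is an illustration under assumptions, not `spark_fuse`'s actual implementation; `embed_partition` and `CountingEmbeddings` are hypothetical names.

```python
from itertools import islice


def embed_partition(texts, embeddings_factory, batch_size=8):
    """Yield (text, vector) pairs, building one client for the whole partition."""
    client = embeddings_factory()  # constructed where the work runs
    iterator = iter(texts)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            break
        # One embed_documents call per batch keeps the request count low.
        for text, vector in zip(batch, client.embed_documents(batch)):
            yield text, vector


class CountingEmbeddings:
    """Hypothetical stand-in client that records how many calls it receives."""

    def __init__(self):
        self.calls = 0

    def embed_documents(self, texts):
        self.calls += 1
        return [[float(len(t))] for t in texts]


client = CountingEmbeddings()
pairs = list(embed_partition(["a", "bb", "ccc"], lambda: client, batch_size=2))
print(pairs)         # three (text, vector) pairs
print(client.calls)  # 2: one batch of two texts, then one batch of one
```

A larger `batch_size` means fewer requests per partition, at the cost of bigger payloads per request.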


## Chunk long documents with a text splitter

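When documents are long, add a LangChain splitter via `text_splitter`. Each document is split into chunks, every chunk is embedded, and the chunk embeddings are combined with the chosen aggregation strategy (`mean` below). The sample texts here are all shorter than 64 characters, so with `chunk_size=64` each one stays a single chunk; lower `chunk_size` to see splitting take effect.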


In [None]:
splitter = RecursiveCharacterTextSplitter(chunk_size=64, chunk_overlap=16)

split_embedded = with_langchain_embeddings(
    df,
    input_col="text",
    embeddings=lambda: DatabricksEmbeddings(endpoint="databricks-bge-large-en"),
    text_splitter=splitter,
    aggregation="mean",
    output_col="embedding_mean",
    batch_size=8,
)

split_embedded.select("id", "embedding_mean").show(3, truncate=False)
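
The `mean` aggregation can be pictured as an element-wise average over a document's chunk vectors. The sketch below assumes that is what `mean` denotes; it is not taken from `spark_fuse`, and `mean_pool` is a hypothetical name.

```python
def mean_pool(chunk_vectors):
    """Element-wise mean of equally sized chunk vectors."""
    if not chunk_vectors:
        raise ValueError("need at least one chunk vector")
    dim = len(chunk_vectors[0])
    if any(len(v) != dim for v in chunk_vectors):
        raise ValueError("all chunk vectors must have the same dimension")
    count = len(chunk_vectors)
    return [sum(v[i] for v in chunk_vectors) / count for i in range(dim)]


chunks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mean_pool(chunks))  # [3.0, 4.0]
```

A document that yields a single chunk keeps that chunk's vector unchanged under this scheme.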


## Offline or cost-free testing with a stub

For quick validation without network calls, define a minimal embeddings class. This keeps the schema and workflow identical while producing deterministic vectors.

In [None]:
class StubEmbeddings:
    """Deterministic length-and-checksum embeddings for local testing."""

    @staticmethod
    def _vec(text):
        # Map length and a simple character checksum into a tiny 2-d vector (demonstration only).
        length = float(len(text))
        checksum = float(sum(ord(ch) for ch in text) % 97)
        return [length, checksum]

    def embed_documents(self, texts):
        return [self._vec(t) for t in texts]

    def embed_query(self, text):
        # Included for parity with LangChain's Embeddings interface.
        return self._vec(text)


stubbed = with_langchain_embeddings(
    df,
    input_col="text",
    embeddings=StubEmbeddings(),
    output_col="embedding_stub",
    batch_size=4,
)

stubbed.select("id", "embedding_stub").show(truncate=False)
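
If you want stub vectors with more than two dimensions, for example to exercise downstream code that expects wider arrays, a hash-based variant is one option. The sketch below is an assumption-laden alternative, not part of `spark-fuse` or LangChain; `HashStubEmbeddings` is a hypothetical name, and its values carry no semantic meaning.

```python
import hashlib


class HashStubEmbeddings:
    """Hypothetical stub: deterministic vectors of a chosen dimension from SHA-256."""

    def __init__(self, dim=8):
        if not 1 <= dim <= 32:
            raise ValueError("dim must be between 1 and 32 (SHA-256 yields 32 bytes)")
        self.dim = dim

    def _vec(self, text):
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Scale each byte from [0, 255] into [0.0, 1.0].
        return [byte / 255.0 for byte in digest[: self.dim]]

    def embed_documents(self, texts):
        return [self._vec(t) for t in texts]

    def embed_query(self, text):
        return self._vec(text)


stub = HashStubEmbeddings(dim=4)
first, second = stub.embed_documents(["hiking backpack", "hiking backpack"])
print(first == second, len(first))  # True 4
```

Because the vectors depend only on the input text, repeated runs produce the same output, which makes assertions in tests straightforward.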

## Cleanup

Stop the SparkSession when finished.

In [None]:
spark.stop()