Skip to content

Commit 602a86a

Browse files
authored
Add workspace client as optional arg to dspy.DatabricksRM (#174)
1 parent c9b48dc commit 602a86a

File tree

2 files changed

+153
-53
lines changed

2 files changed

+153
-53
lines changed

integrations/dspy/src/databricks_dspy/retrievers/databricks_rm.py

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any
66

77
import dspy
8+
from databricks.sdk import WorkspaceClient
89
from dspy.primitives.prediction import Prediction
910

1011
logger = logging.getLogger(__name__)
@@ -38,6 +39,10 @@ class DatabricksRM(dspy.Retrieve):
3839
3940
```python
4041
from databricks.vector_search.client import VectorSearchClient
42+
from databricks.sdk import WorkspaceClient
43+
44+
# Create a Databricks workspace client
45+
w = WorkspaceClient()
4146
4247
# Create a Databricks Vector Search Endpoint
4348
client = VectorSearchClient()
@@ -65,6 +70,7 @@ class DatabricksRM(dspy.Retrieve):
6570
docs_id_column_name="id",
6671
text_column_name="field2",
6772
k=3,
73+
workspace_client=w,
6874
)
6975
```
7076
@@ -90,6 +96,7 @@ def __init__(
9096
docs_uri_column_name: str | None = None,
9197
text_column_name: str = "text",
9298
use_with_databricks_agent_framework: bool = False,
99+
workspace_client: WorkspaceClient | None = None,
93100
):
94101
"""
95102
Args:
@@ -122,6 +129,8 @@ def __init__(
122129
containing document text to retrieve.
123130
use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way
124131
that is compatible with the Databricks Mosaic Agent Framework.
132+
workspace_client (Optional[WorkspaceClient]): The workspace client to use. If not
133+
provided, a new one will be created with default credentials from the environment.
125134
"""
126135
super().__init__(k=k)
127136
self.databricks_token = databricks_token or os.environ.get("DATABRICKS_TOKEN")
@@ -154,6 +163,32 @@ def __init__(
154163
"library. Please install mlflow via `pip install mlflow`."
155164
) from None
156165

166+
# Use provided workspace client or create one based on credentials
167+
if workspace_client:
168+
self.workspace_client = workspace_client
169+
elif databricks_client_secret and databricks_client_id:
170+
# Use client ID and secret for authentication if they are provided
171+
self.workspace_client = WorkspaceClient(
172+
client_id=databricks_client_id,
173+
client_secret=databricks_client_secret,
174+
)
175+
logger.info(
176+
"Creating Databricks workspace client using service principal authentication."
177+
)
178+
elif databricks_token and databricks_endpoint:
179+
# token-based authentication
180+
self.workspace_client = WorkspaceClient(
181+
host=databricks_endpoint,
182+
token=databricks_token,
183+
)
184+
logger.info("Creating Databricks workspace client using token authentication.")
185+
else:
186+
# fallback to default authentication, i.e., using `~/.databrickscfg` file.
187+
self.workspace_client = WorkspaceClient()
188+
logger.info(
189+
"Creating Databricks workspace client using credentials from `~/.databrickscfg` file."
190+
)
191+
157192
def _extract_doc_ids(self, item: dict[str, Any]) -> str:
158193
"""Extracts the document id from a search result
159194
@@ -237,10 +272,6 @@ def forward(
237272
query_type=query_type,
238273
query_text=query_text,
239274
query_vector=query_vector,
240-
databricks_token=self.databricks_token,
241-
databricks_endpoint=self.databricks_endpoint,
242-
databricks_client_id=self.databricks_client_id,
243-
databricks_client_secret=self.databricks_client_secret,
244275
filters_json=filters_json or self.filters_json,
245276
)
246277

@@ -305,15 +336,10 @@ def _query_vector_search_index(
305336
query_type: str,
306337
query_text: str | None,
307338
query_vector: list[float] | None,
308-
databricks_token: str | None,
309-
databricks_endpoint: str | None,
310-
databricks_client_id: str | None,
311-
databricks_client_secret: str | None,
312339
filters_json: str | None,
313340
) -> dict[str, Any]:
314341
"""
315342
Query a Databricks Vector Search Index via the Databricks SDK.
316-
Assumes that the databricks-sdk Python library is installed.
317343
318344
Args:
319345
index_name (str): Name of the Databricks vector search index to query
@@ -324,49 +350,14 @@ def _query_vector_search_index(
324350
query_vector (Optional[list[float]]): Numeric query vector for which to find relevant
325351
documents. Exactly one of query_text or query_vector must be specified.
326352
filters_json (Optional[str]): JSON string representing additional query filters.
327-
databricks_token (str): Databricks authentication token. If not specified,
328-
the token is resolved from the current environment.
329-
databricks_endpoint (str): Databricks index endpoint url. If not specified,
330-
the endpoint is resolved from the current environment.
331-
databricks_client_id (str): Databricks service principal id. If not specified,
332-
the token is resolved from the current environment (DATABRICKS_CLIENT_ID).
333-
databricks_client_secret (str): Databricks service principal secret. If not specified,
334-
the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
335353
336354
Returns:
337355
dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
338356
"""
339-
340-
from databricks.sdk import WorkspaceClient
341-
342357
if (query_text, query_vector).count(None) != 1:
343358
raise ValueError("Exactly one of query_text or query_vector must be specified.")
344359

345-
if databricks_client_secret and databricks_client_id:
346-
# Use client ID and secret for authentication if they are provided
347-
databricks_client = WorkspaceClient(
348-
client_id=databricks_client_id,
349-
client_secret=databricks_client_secret,
350-
)
351-
logger.info(
352-
"Creating Databricks workspace client using service principal authentication."
353-
)
354-
355-
elif databricks_token and databricks_endpoint:
356-
# token-based authentication
357-
databricks_client = WorkspaceClient(
358-
host=databricks_endpoint,
359-
token=databricks_token,
360-
)
361-
logger.info("Creating Databricks workspace client using token authentication.")
362-
else:
363-
# fallback to default authentication, i.e., using `~/.databrickscfg` file.
364-
databricks_client = WorkspaceClient()
365-
logger.info(
366-
"Creating Databricks workspace client using credentials from `~/.databrickscfg` file."
367-
)
368-
369-
return databricks_client.vector_search_indexes.query_index(
360+
return self.workspace_client.vector_search_indexes.query_index(
370361
index_name=index_name,
371362
query_type=query_type,
372363
query_text=query_text,

integrations/dspy/tests/unit_tests/retrievers/test_databricks_rm.py

Lines changed: 117 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def mock_vector_search_response_with_uri():
4949
}
5050

5151

52-
@patch("databricks.sdk.WorkspaceClient")
52+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
5353
def test_databricks_rm_forward_string_query(mock_workspace_client, mock_vector_search_response):
5454
"""Test forward method with string query and ANN search."""
5555
mock_client = MagicMock()
@@ -84,7 +84,7 @@ def test_databricks_rm_forward_string_query(mock_workspace_client, mock_vector_s
8484
assert result.doc_ids[0] == "doc1"
8585

8686

87-
@patch("databricks.sdk.WorkspaceClient")
87+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
8888
def test_databricks_rm_forward_vector_query(mock_workspace_client, mock_vector_search_response):
8989
"""Test forward method with vector query and HYBRID search."""
9090
mock_client = MagicMock()
@@ -107,7 +107,7 @@ def test_databricks_rm_forward_vector_query(mock_workspace_client, mock_vector_s
107107
assert set(call_args["columns"]) == {"id", "text"}
108108

109109

110-
@patch("databricks.sdk.WorkspaceClient")
110+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
111111
def test_databricks_rm_agent_framework_format(
112112
mock_workspace_client, mock_vector_search_response_with_uri
113113
):
@@ -138,8 +138,12 @@ def test_databricks_rm_agent_framework_format(
138138
assert doc["type"] == "Document"
139139

140140

141-
def test_databricks_rm_initialization():
141+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
142+
def test_databricks_rm_initialization(mock_workspace_client):
142143
"""Test initialization with token authentication."""
144+
mock_client = MagicMock()
145+
mock_workspace_client.return_value = mock_client
146+
143147
rm = DatabricksRM(
144148
databricks_index_name="test_index",
145149
databricks_endpoint="https://test.databricks.com",
@@ -155,8 +159,75 @@ def test_databricks_rm_initialization():
155159
assert rm.text_column_name == "text"
156160
assert not rm.use_with_databricks_agent_framework
157161

162+
# Verify WorkspaceClient was created with token auth
163+
mock_workspace_client.assert_called_once_with(
164+
host="https://test.databricks.com",
165+
token="test_token",
166+
)
167+
# Workspace client should be set
168+
assert rm.workspace_client == mock_client
169+
170+
171+
def test_databricks_rm_initialization_with_custom_workspace_client():
172+
"""Test initialization with custom workspace_client."""
173+
mock_workspace_client = MagicMock()
174+
175+
rm = DatabricksRM(
176+
databricks_index_name="test_index",
177+
workspace_client=mock_workspace_client,
178+
k=5,
179+
)
180+
181+
assert rm.databricks_index_name == "test_index"
182+
assert rm.workspace_client == mock_workspace_client
183+
assert rm.k == 5
184+
assert rm.docs_id_column_name == "id"
185+
assert rm.text_column_name == "text"
186+
assert not rm.use_with_databricks_agent_framework
187+
188+
189+
def test_databricks_rm_query_with_custom_workspace_client():
190+
"""Test that custom workspace_client is used for queries."""
191+
mock_workspace_client = MagicMock()
192+
193+
mock_response = {
194+
"result": {
195+
"data_array": [
196+
["doc1", "This is document 1", 0.95],
197+
]
198+
},
199+
"manifest": {
200+
"columns": [
201+
{"name": "id"},
202+
{"name": "text"},
203+
{"name": "score"},
204+
]
205+
},
206+
}
207+
mock_workspace_client.vector_search_indexes.query_index.return_value.as_dict.return_value = (
208+
mock_response
209+
)
158210

159-
@patch("databricks.sdk.WorkspaceClient")
211+
rm = DatabricksRM(
212+
databricks_index_name="test_index",
213+
workspace_client=mock_workspace_client,
214+
)
215+
216+
result = rm("test query")
217+
218+
# Verify that the custom workspace_client was used for the query
219+
mock_workspace_client.vector_search_indexes.query_index.assert_called_once()
220+
call_args = mock_workspace_client.vector_search_indexes.query_index.call_args[1]
221+
assert call_args["index_name"] == "test_index"
222+
assert call_args["query_text"] == "test query"
223+
224+
# Verify results
225+
assert len(result.docs) == 1
226+
assert result.docs[0] == "This is document 1"
227+
assert result.doc_ids[0] == "doc1"
228+
229+
230+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
160231
def test_databricks_rm_service_principal_auth(mock_workspace_client, mock_vector_search_response):
161232
"""Test querying with service principal authentication."""
162233
mock_client = MagicMock()
@@ -180,15 +251,19 @@ def test_databricks_rm_service_principal_auth(mock_workspace_client, mock_vector
180251
)
181252

182253

183-
def test_databricks_rm_invalid_query_type():
254+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
255+
def test_databricks_rm_invalid_query_type(mock_workspace_client):
184256
"""Test forward method with invalid query type."""
257+
mock_client = MagicMock()
258+
mock_workspace_client.return_value = mock_client
259+
185260
rm = DatabricksRM(databricks_index_name="test_index")
186261

187262
with pytest.raises(ValueError, match="Invalid query_type: INVALID"):
188263
rm("test query", query_type="INVALID")
189264

190265

191-
@patch("databricks.sdk.WorkspaceClient")
266+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
192267
def test_databricks_rm_missing_column_error(mock_workspace_client):
193268
"""Test error when ID column is missing from index."""
194269
mock_client = MagicMock()
@@ -210,7 +285,7 @@ def test_databricks_rm_missing_column_error(mock_workspace_client):
210285
rm("test query")
211286

212287

213-
@patch("databricks.sdk.WorkspaceClient")
288+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
214289
def test_databricks_rm_result_sorting(mock_workspace_client):
215290
"""Test that results are sorted by score in descending order."""
216291
mock_client = MagicMock()
@@ -236,3 +311,37 @@ def test_databricks_rm_result_sorting(mock_workspace_client):
236311
# Should be sorted by score (highest first)
237312
assert result.doc_ids == ["doc1", "doc3", "doc2"]
238313
assert result.docs == ["Document 1", "Document 3", "Document 2"]
314+
315+
316+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
317+
def test_databricks_rm_query_with_invalid_credentials(mock_workspace_client):
318+
"""Test that authentication errors are raised during query execution."""
319+
mock_client = MagicMock()
320+
mock_workspace_client.return_value = mock_client
321+
# Simulate authentication error when querying the index
322+
mock_client.vector_search_indexes.query_index.side_effect = Exception("Authentication failed")
323+
324+
rm = DatabricksRM(
325+
databricks_index_name="test_index",
326+
databricks_token="invalid_token",
327+
databricks_endpoint="https://test.databricks.com",
328+
)
329+
330+
# Error occurs when trying to query, not during initialization
331+
with pytest.raises(Exception, match="Authentication failed"):
332+
rm("test query")
333+
334+
335+
@patch("databricks_dspy.retrievers.databricks_rm.WorkspaceClient")
336+
def test_databricks_rm_fallback_to_default_auth(mock_workspace_client):
337+
"""Test fallback to default authentication when no credentials provided."""
338+
mock_client = MagicMock()
339+
mock_workspace_client.return_value = mock_client
340+
341+
rm = DatabricksRM(databricks_index_name="test_index")
342+
343+
assert rm.databricks_index_name == "test_index"
344+
345+
# Verify WorkspaceClient was created with no auth params (default auth)
346+
mock_workspace_client.assert_called_once_with()
347+
assert rm.workspace_client == mock_client

0 commit comments

Comments
 (0)