@@ -49,7 +49,7 @@ def mock_vector_search_response_with_uri():
49
49
}
50
50
51
51
52
- @patch ("databricks.sdk .WorkspaceClient" )
52
+ @patch ("databricks_dspy.retrievers.databricks_rm .WorkspaceClient" )
53
53
def test_databricks_rm_forward_string_query (mock_workspace_client , mock_vector_search_response ):
54
54
"""Test forward method with string query and ANN search."""
55
55
mock_client = MagicMock ()
@@ -84,7 +84,7 @@ def test_databricks_rm_forward_string_query(mock_workspace_client, mock_vector_s
84
84
assert result .doc_ids [0 ] == "doc1"
85
85
86
86
87
- @patch ("databricks.sdk .WorkspaceClient" )
87
+ @patch ("databricks_dspy.retrievers.databricks_rm .WorkspaceClient" )
88
88
def test_databricks_rm_forward_vector_query (mock_workspace_client , mock_vector_search_response ):
89
89
"""Test forward method with vector query and HYBRID search."""
90
90
mock_client = MagicMock ()
@@ -107,7 +107,7 @@ def test_databricks_rm_forward_vector_query(mock_workspace_client, mock_vector_s
107
107
assert set (call_args ["columns" ]) == {"id" , "text" }
108
108
109
109
110
- @patch ("databricks.sdk .WorkspaceClient" )
110
+ @patch ("databricks_dspy.retrievers.databricks_rm .WorkspaceClient" )
111
111
def test_databricks_rm_agent_framework_format (
112
112
mock_workspace_client , mock_vector_search_response_with_uri
113
113
):
@@ -138,8 +138,12 @@ def test_databricks_rm_agent_framework_format(
138
138
assert doc ["type" ] == "Document"
139
139
140
140
141
- def test_databricks_rm_initialization ():
141
+ @patch ("databricks_dspy.retrievers.databricks_rm.WorkspaceClient" )
142
+ def test_databricks_rm_initialization (mock_workspace_client ):
142
143
"""Test initialization with token authentication."""
144
+ mock_client = MagicMock ()
145
+ mock_workspace_client .return_value = mock_client
146
+
143
147
rm = DatabricksRM (
144
148
databricks_index_name = "test_index" ,
145
149
databricks_endpoint = "https://test.databricks.com" ,
@@ -155,8 +159,75 @@ def test_databricks_rm_initialization():
155
159
assert rm .text_column_name == "text"
156
160
assert not rm .use_with_databricks_agent_framework
157
161
162
+ # Verify WorkspaceClient was created with token auth
163
+ mock_workspace_client .assert_called_once_with (
164
+ host = "https://test.databricks.com" ,
165
+ token = "test_token" ,
166
+ )
167
+ # Workspace client should be set
168
+ assert rm .workspace_client == mock_client
169
+
170
+
171
+ def test_databricks_rm_initialization_with_custom_workspace_client ():
172
+ """Test initialization with custom workspace_client."""
173
+ mock_workspace_client = MagicMock ()
174
+
175
+ rm = DatabricksRM (
176
+ databricks_index_name = "test_index" ,
177
+ workspace_client = mock_workspace_client ,
178
+ k = 5 ,
179
+ )
180
+
181
+ assert rm .databricks_index_name == "test_index"
182
+ assert rm .workspace_client == mock_workspace_client
183
+ assert rm .k == 5
184
+ assert rm .docs_id_column_name == "id"
185
+ assert rm .text_column_name == "text"
186
+ assert not rm .use_with_databricks_agent_framework
187
+
188
+
189
+ def test_databricks_rm_query_with_custom_workspace_client ():
190
+ """Test that custom workspace_client is used for queries."""
191
+ mock_workspace_client = MagicMock ()
192
+
193
+ mock_response = {
194
+ "result" : {
195
+ "data_array" : [
196
+ ["doc1" , "This is document 1" , 0.95 ],
197
+ ]
198
+ },
199
+ "manifest" : {
200
+ "columns" : [
201
+ {"name" : "id" },
202
+ {"name" : "text" },
203
+ {"name" : "score" },
204
+ ]
205
+ },
206
+ }
207
+ mock_workspace_client .vector_search_indexes .query_index .return_value .as_dict .return_value = (
208
+ mock_response
209
+ )
158
210
159
- @patch ("databricks.sdk.WorkspaceClient" )
211
+ rm = DatabricksRM (
212
+ databricks_index_name = "test_index" ,
213
+ workspace_client = mock_workspace_client ,
214
+ )
215
+
216
+ result = rm ("test query" )
217
+
218
+ # Verify that the custom workspace_client was used for the query
219
+ mock_workspace_client .vector_search_indexes .query_index .assert_called_once ()
220
+ call_args = mock_workspace_client .vector_search_indexes .query_index .call_args [1 ]
221
+ assert call_args ["index_name" ] == "test_index"
222
+ assert call_args ["query_text" ] == "test query"
223
+
224
+ # Verify results
225
+ assert len (result .docs ) == 1
226
+ assert result .docs [0 ] == "This is document 1"
227
+ assert result .doc_ids [0 ] == "doc1"
228
+
229
+
230
+ @patch ("databricks_dspy.retrievers.databricks_rm.WorkspaceClient" )
160
231
def test_databricks_rm_service_principal_auth (mock_workspace_client , mock_vector_search_response ):
161
232
"""Test querying with service principal authentication."""
162
233
mock_client = MagicMock ()
@@ -180,15 +251,19 @@ def test_databricks_rm_service_principal_auth(mock_workspace_client, mock_vector
180
251
)
181
252
182
253
183
- def test_databricks_rm_invalid_query_type ():
254
+ @patch ("databricks_dspy.retrievers.databricks_rm.WorkspaceClient" )
255
+ def test_databricks_rm_invalid_query_type (mock_workspace_client ):
184
256
"""Test forward method with invalid query type."""
257
+ mock_client = MagicMock ()
258
+ mock_workspace_client .return_value = mock_client
259
+
185
260
rm = DatabricksRM (databricks_index_name = "test_index" )
186
261
187
262
with pytest .raises (ValueError , match = "Invalid query_type: INVALID" ):
188
263
rm ("test query" , query_type = "INVALID" )
189
264
190
265
191
- @patch ("databricks.sdk .WorkspaceClient" )
266
+ @patch ("databricks_dspy.retrievers.databricks_rm .WorkspaceClient" )
192
267
def test_databricks_rm_missing_column_error (mock_workspace_client ):
193
268
"""Test error when ID column is missing from index."""
194
269
mock_client = MagicMock ()
@@ -210,7 +285,7 @@ def test_databricks_rm_missing_column_error(mock_workspace_client):
210
285
rm ("test query" )
211
286
212
287
213
- @patch ("databricks.sdk .WorkspaceClient" )
288
+ @patch ("databricks_dspy.retrievers.databricks_rm .WorkspaceClient" )
214
289
def test_databricks_rm_result_sorting (mock_workspace_client ):
215
290
"""Test that results are sorted by score in descending order."""
216
291
mock_client = MagicMock ()
@@ -236,3 +311,37 @@ def test_databricks_rm_result_sorting(mock_workspace_client):
236
311
# Should be sorted by score (highest first)
237
312
assert result .doc_ids == ["doc1" , "doc3" , "doc2" ]
238
313
assert result .docs == ["Document 1" , "Document 3" , "Document 2" ]
314
+
315
+
316
+ @patch ("databricks_dspy.retrievers.databricks_rm.WorkspaceClient" )
317
+ def test_databricks_rm_query_with_invalid_credentials (mock_workspace_client ):
318
+ """Test that authentication errors are raised during query execution."""
319
+ mock_client = MagicMock ()
320
+ mock_workspace_client .return_value = mock_client
321
+ # Simulate authentication error when querying the index
322
+ mock_client .vector_search_indexes .query_index .side_effect = Exception ("Authentication failed" )
323
+
324
+ rm = DatabricksRM (
325
+ databricks_index_name = "test_index" ,
326
+ databricks_token = "invalid_token" ,
327
+ databricks_endpoint = "https://test.databricks.com" ,
328
+ )
329
+
330
+ # Error occurs when trying to query, not during initialization
331
+ with pytest .raises (Exception , match = "Authentication failed" ):
332
+ rm ("test query" )
333
+
334
+
335
+ @patch ("databricks_dspy.retrievers.databricks_rm.WorkspaceClient" )
336
+ def test_databricks_rm_fallback_to_default_auth (mock_workspace_client ):
337
+ """Test fallback to default authentication when no credentials provided."""
338
+ mock_client = MagicMock ()
339
+ mock_workspace_client .return_value = mock_client
340
+
341
+ rm = DatabricksRM (databricks_index_name = "test_index" )
342
+
343
+ assert rm .databricks_index_name == "test_index"
344
+
345
+ # Verify WorkspaceClient was created with no auth params (default auth)
346
+ mock_workspace_client .assert_called_once_with ()
347
+ assert rm .workspace_client == mock_client
0 commit comments