/
qdrant_handler.py
476 lines (408 loc) · 18.6 KB
/
qdrant_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
from collections import OrderedDict
from typing import Any, List, Optional
from itertools import zip_longest
from qdrant_client import QdrantClient, models
import pandas as pd
from mindsdb.integrations.libs.response import HandlerResponse
from mindsdb.integrations.libs.const import HANDLER_CONNECTION_ARG_TYPE as ARG_TYPE
from mindsdb.integrations.libs.response import RESPONSE_TYPE
from mindsdb.integrations.libs.response import HandlerResponse as Response
from mindsdb.integrations.libs.response import HandlerStatusResponse as StatusResponse
from mindsdb.integrations.libs.vectordatabase_handler import (
FilterCondition,
FilterOperator,
TableField,
VectorStoreHandler,
)
from mindsdb.utilities import log
class QdrantHandler(VectorStoreHandler):
    """Handles connection and execution of the Qdrant statements."""

    name = "qdrant"

    def __init__(self, name: str, **kwargs):
        """Create the handler and open a Qdrant client.

        Args:
            name (str): Handler name, forwarded to the base class.
            **kwargs: Must contain ``connection_data``: a dict holding the
                mandatory ``collection_config`` entry plus the arguments that
                are forwarded to ``QdrantClient``.
        """
        super().__init__(name)
        connection_data = kwargs.get("connection_data").copy()
        # Qdrant offers several configuration and optimization options at the time of collection creation
        # Since the create table statement doesn't have a way to pass these options
        # We are requiring the user to pass these options in the connection_data
        # These options are documented here. https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
        self.collection_config = connection_data.pop("collection_config")
        # Define the attribute up front so that a failed connect() leaves a
        # well-defined (None) client instead of a missing attribute.
        self._client = None
        self.connect(**connection_data)

    def connect(self, **kwargs):
        """Connect to a Qdrant instance.

        A Qdrant client can be instantiated with a REST, GRPC interface or in-memory for testing.
        To use the in-memory instance, specify the location argument as ':memory:'.

        Args:
            **kwargs: Arguments forwarded verbatim to ``QdrantClient``.

        Returns:
            The connected client, or None if the client could not be created
            (the error is logged, not raised).
        """
        if self.is_connected:
            return self._client

        try:
            self._client = QdrantClient(**kwargs)
        except Exception as e:
            log.logger.error(f"Error instantiating a Qdrant client: {e}")
            self._client = None
            self.is_connected = False
            return None

        self.is_connected = True
        return self._client

    def disconnect(self):
        """Close the database connection."""
        if self.is_connected:
            self._client.close()
            self._client = None
            self.is_connected = False

    def check_connection(self) -> StatusResponse:
        """Check the connection to the Qdrant database.

        Returns:
            StatusResponse: Indicates if the connection is alive
        """
        need_to_close = not self.is_connected

        try:
            # Using a trivial operation to get the connection status
            # As there isn't a universal ping method for the REST, GRPC and in-memory interface
            self._client.get_locks()
            response_code = StatusResponse(True)
        except Exception as e:
            log.logger.error(f"Error connecting to a Qdrant instance: {e}")
            response_code = StatusResponse(False, error_message=str(e))

        # The follow-up runs after the try block (not in a `finally`) so that
        # `response_code` is guaranteed to be bound when it is inspected.
        if response_code.success and need_to_close:
            self.disconnect()
        if not response_code.success and self.is_connected:
            self.is_connected = False

        return response_code

    def drop_table(self, table_name: str, if_exists=True) -> HandlerResponse:
        """Delete a collection from the Qdrant Instance.

        Args:
            table_name (str): The name of the collection to be dropped
            if_exists (bool, optional): Throws an error if this value is set to false and the collection doesn't exist. Defaults to True.

        Returns:
            HandlerResponse: OK when the collection was deleted or ``if_exists``
                is set, otherwise an ERROR response.
        """
        result = self._client.delete_collection(table_name)
        if result or if_exists:
            return Response(resp_type=RESPONSE_TYPE.OK)

        return Response(
            resp_type=RESPONSE_TYPE.ERROR,
            error_message=f"Table {table_name} does not exist!",
        )

    def get_tables(self) -> HandlerResponse:
        """Get the list of collections in the Qdrant instance.

        Returns:
            HandlerResponse: The common query handler response with a list of table names
        """
        collection_response = self._client.get_collections()
        collections_name = pd.DataFrame(
            columns=["table_name"],
            data=[collection.name for collection in collection_response.collections],
        )
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=collections_name)

    def get_columns(self, table_name: str) -> HandlerResponse:
        """Return the columns of a collection.

        Args:
            table_name (str): The name of the collection to describe

        Returns:
            HandlerResponse: An ERROR response if looking up the collection
                raises ValueError, otherwise the base class column listing.
        """
        try:
            # Only used as an existence probe; the result itself is not needed
            self._client.get_collection(table_name)
        except ValueError:
            return Response(
                resp_type=RESPONSE_TYPE.ERROR,
                error_message=f"Table {table_name} does not exist!",
            )
        return super().get_columns(table_name)

    def insert(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ) -> HandlerResponse:
        """Handler for the insert query

        Args:
            table_name (str): The name of the table to be inserted into
            data (pd.DataFrame): The data to be inserted
            columns (List[str], optional): Columns to be inserted into. Unused as the values are derived from the "data" argument. Defaults to None.

        Raises:
            ValueError: If the number of ids and embeddings differ

        Returns:
            HandlerResponse: The common query handler response
        """
        records = data.to_dict(orient="list")
        raw_ids = records[TableField.ID.value]
        vectors = records[TableField.EMBEDDINGS.value]
        # Explicit check instead of `assert`, which is stripped under `python -O`
        if len(raw_ids) != len(vectors):
            raise ValueError("Number of ids and embeddings must be equal")

        # Qdrant doesn't have a distinction between documents and metadata
        # Any data that is to be stored should be placed in the "payload" field
        content_list = records.get(TableField.CONTENT.value) or []
        metadata_list = records.get(TableField.METADATA.value) or []

        payloads = []
        # Iterate alongside the ids so that exactly one payload is produced per
        # point. Skipping empty payloads would shift every following payload
        # onto the wrong id.
        for _, document, metadata in zip_longest(
            raw_ids, content_list, metadata_list, fillvalue=None
        ):
            payload = {}
            # Insert the document with a "document" key in the payload
            if document is not None:
                payload["document"] = document
            # Unpack all the metadata fields into the payload
            if metadata is not None:
                payload = {**payload, **metadata}
            payloads.append(payload)

        # IDs can be either integers or strings(UUIDs)
        # The following step ensures proper type of numeric values
        ids = [
            int(point_id) if str(point_id).isdigit() else point_id
            for point_id in raw_ids
        ]

        self._client.upsert(
            table_name,
            points=models.Batch(ids=ids, vectors=vectors, payloads=payloads),
        )
        return Response(resp_type=RESPONSE_TYPE.OK)

    def create_table(self, table_name: str, if_not_exists=True) -> HandlerResponse:
        """Create a collection with the given name in the Qdrant database.

        Args:
            table_name (str): Name of the table(Collection) to be created
            if_not_exists (bool, optional): Throws an error if this value is set to false and the collection already exists. Defaults to True.

        Returns:
            HandlerResponse: The common query handler response
        """
        try:
            # Create a collection with the collection name and collection_config set during __init__
            self._client.create_collection(table_name, self.collection_config)
        except ValueError:
            # A ValueError from create_collection is treated as "already exists";
            # it is only reported when the caller asked for strict creation.
            if if_not_exists is False:
                return Response(
                    resp_type=RESPONSE_TYPE.ERROR,
                    error_message=f"Table {table_name} already exists!",
                )

        return Response(resp_type=RESPONSE_TYPE.OK)

    def _get_qdrant_filter(self, operator: FilterOperator, value: Any) -> dict:
        """Map the filter operator to the Qdrant filter

        The clause is built lazily, only for the requested operator:
        generating models.Range() with a str type value fails, so every
        possible clause cannot be constructed eagerly up front.

        Args:
            operator (FilterOperator): FilterOperator specified in the query. Eg >=, <=, =
            value (Any): Value specified in the query

        Raises:
            Exception: If an unsupported operator is specified

        Returns:
            dict: A dict of Qdrant filtering clauses
        """
        if operator == FilterOperator.EQUAL:
            return {"match": models.MatchValue(value=value)}
        if operator == FilterOperator.NOT_EQUAL:
            # "except" is a Python keyword, hence the dict unpacking
            return {"match": models.MatchExcept(**{"except": [value]})}

        # Name of the models.Range bound used by each comparison operator
        range_bounds = {
            FilterOperator.LESS_THAN: "lt",
            FilterOperator.LESS_THAN_OR_EQUAL: "lte",
            FilterOperator.GREATER_THAN: "gt",
            FilterOperator.GREATER_THAN_OR_EQUAL: "gte",
        }
        bound = range_bounds.get(operator)
        if bound is None:
            raise Exception(f"Operator {operator} is not supported by Qdrant!")
        return {"range": models.Range(**{bound: value})}

    def _translate_filter_conditions(
        self, conditions: List[FilterCondition]
    ) -> Optional[models.Filter]:
        """
        Translate a list of FilterCondition objects into a Qdrant filter.
        Filtering clause docs can be found here: https://qdrant.tech/documentation/concepts/filtering/

        Only conditions on metadata columns are translated; the leading
        "metadata" segment is stripped and the rest of the column (including
        any further dots) is used as the payload key.

        E.g.,
        [
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.LESS_THAN,
                value=7132423,
            ),
            FilterCondition(
                column="metadata.created_at",
                op=FilterOperator.GREATER_THAN,
                value=2323432,
            )
        ]
        -->
        models.Filter(
            must=[
                models.FieldCondition(
                    key="created_at",
                    range=models.Range(lt=7132423),
                ),
                models.FieldCondition(
                    key="created_at",
                    range=models.Range(gt=2323432),
                ),
            ]
        )

        Returns:
            A models.Filter combining all metadata conditions with "must", or
            None when there is no metadata condition.
        """
        if not conditions:
            return None

        # We ignore all non-metadata conditions
        qdrant_filters = []
        for condition in conditions:
            if not condition.column.startswith(TableField.METADATA.value):
                continue
            # Drop only the first segment so nested keys such as
            # "metadata.a.b" map to the payload key "a.b" rather than "b".
            payload_key = condition.column.split(".", 1)[-1]
            qdrant_filters.append(
                models.FieldCondition(
                    key=payload_key,
                    **self._get_qdrant_filter(condition.op, condition.value),
                )
            )

        return models.Filter(must=qdrant_filters) if qdrant_filters else None

    def update(
        self, table_name: str, data: pd.DataFrame, columns: List[str] = None
    ) -> HandlerResponse:
        """
        Update data in the Qdrant database.
        TODO: Update for vector DBs has not been implemented.
        Ref: https://github.com/mindsdb/mindsdb/blob/a870ba93b0afee234e48c0268489a94a6e6fd5f7/mindsdb/integrations/libs/vectordatabase_handler.py#L273-L277
        """
        return super().update(table_name, data, columns)

    def select(
        self,
        table_name: str,
        columns: Optional[List[str]] = None,
        conditions: Optional[List[FilterCondition]] = None,
        offset: int = 0,
        limit: int = 10,
    ) -> HandlerResponse:
        """Select query handler

        Eg: SELECT * FROM qdrant.test_table

        Args:
            table_name (str): The name of the table to be queried
            columns (Optional[List[str]], optional): List of column names specified in the query. Defaults to None.
            conditions (Optional[List[FilterCondition]], optional): List of "where" conditionals. Defaults to None.
            offset (int, optional): Offset the results by the provided value. Defaults to 0.
            limit (int, optional): Number of results to return. Defaults to 10.

        Returns:
            HandlerResponse: The common query handler response
        """
        # Validate and set offset and limit as None is passed if not set in the query
        offset = offset if offset is not None else 0
        limit = limit if limit is not None else 10
        conditions = conditions or []

        # Filter conditions
        vector_filter = [
            condition.value
            for condition in conditions
            if condition.column == TableField.SEARCH_VECTOR.value
        ]
        id_filters = [
            condition.value
            for condition in conditions
            if condition.column == TableField.ID.value
        ]
        query_filters = self._translate_filter_conditions(conditions)

        if id_filters:
            # Prefer returning results by IDs first
            results = self._client.retrieve(table_name, ids=id_filters)
        elif vector_filter:
            # Followed by the search_vector value
            # Perform a similarity search with the first vector filter
            results = self._client.search(
                table_name,
                query_vector=vector_filter[0],
                query_filter=query_filters,
                limit=limit,
                offset=offset,
            )
        else:
            # Scroll, filtered by the metadata conditions when there are any.
            # This branch also covers queries whose conditions target none of
            # id / search_vector / metadata, which previously left `results`
            # unbound.
            results = self._client.scroll(
                table_name, scroll_filter=query_filters, limit=limit, offset=offset
            )[0]

        # Process results
        payload = self._process_select_results(results, columns)
        return Response(resp_type=RESPONSE_TYPE.TABLE, data_frame=payload)

    def _process_select_results(self, results=None, columns=None):
        """Private method to process the results of a select query

        Args:
            results: A List[Records] or List[ScoredPoint]. Defaults to None
            columns: List of column names specified in the query. Defaults to None

        Returns:
            Dataframe: A processed pandas dataframe
        """
        ids, documents, metadata, distances = [], [], [], []

        for result in results or []:
            ids.append(result.id)
            # The documents and metadata are stored as a dict in the payload
            point_payload = result.payload or {}
            documents.append(point_payload.get("document"))
            metadata.append(
                {k: v for k, v in point_payload.items() if k != "document"}
            )
            # Score is only available for similarity search results.
            # Looked up as an attribute: `"score" in result` iterates the
            # model's (field, value) pairs and therefore never matches.
            score = getattr(result, "score", None)
            if score is not None:
                distances.append(score)

        payload = {
            TableField.ID.value: ids,
            TableField.CONTENT.value: documents,
            TableField.METADATA.value: metadata,
        }

        # Filter result columns
        if columns:
            payload = {
                column: payload[column]
                for column in columns
                if column != TableField.EMBEDDINGS.value and column in payload
            }

        # If the distance list is empty, don't add it to the result
        if distances:
            payload[TableField.DISTANCE.value] = distances

        return pd.DataFrame(payload)

    def delete(
        self, table_name: str, conditions: List[FilterCondition] = None
    ) -> HandlerResponse:
        """Delete query handler

        Args:
            table_name (str): The name of the table to delete from
            conditions (List[FilterCondition], optional): List of "where" conditionals. Defaults to None.

        Raises:
            Exception: If no conditions are specified

        Returns:
            HandlerResponse: The common query handler response
        """
        # Normalise so that the default (None) reaches the explicit error
        # below instead of failing while iterating
        conditions = conditions or []
        filters = self._translate_filter_conditions(conditions)

        # Get id filters
        ids = [
            condition.value
            for condition in conditions
            if condition.column == TableField.ID.value
        ]

        if filters is None and not ids:
            raise Exception("Delete query must have at least one condition!")

        if ids:
            self._client.delete(
                table_name, points_selector=models.PointIdsList(points=ids)
            )

        if filters:
            self._client.delete(
                table_name, points_selector=models.FilterSelector(filter=filters)
            )

        return Response(resp_type=RESPONSE_TYPE.OK)
# Connection arguments accepted by the Qdrant handler.
# Each entry describes one argument: its type, a user-facing description and
# whether it is mandatory. Only `collection_config` is mandatory.
connection_args = OrderedDict(
    location={
        "type": ARG_TYPE.STR,
        "description": "If `:memory:` - use in-memory Qdrant instance. If a remote URL - connect to a remote Qdrant instance. Example: `http://localhost:6333`",
        "required": False,
    },
    url={
        "type": ARG_TYPE.STR,
        "description": "URL of Qdrant service. Either host or a string of type [scheme]<host><[port][prefix]. Ex: http://localhost:6333/service/v1",
        # Explicitly optional, like every other way of locating the instance
        "required": False,
    },
    host={
        "type": ARG_TYPE.STR,
        "description": "Host name of Qdrant service. The port and host are used to construct the connection URL.",
        "required": False,
    },
    port={
        "type": ARG_TYPE.INT,
        "description": "Port of the REST API interface. Default: 6333",
        "required": False,
    },
    grpc_port={
        "type": ARG_TYPE.INT,
        "description": "Port of the gRPC interface. Default: 6334",
        "required": False,
    },
    prefer_grpc={
        "type": ARG_TYPE.BOOL,
        "description": "If `true` - use gPRC interface whenever possible in custom methods. Default: false",
        "required": False,
    },
    https={
        "type": ARG_TYPE.BOOL,
        "description": "If `true` - use https protocol.",
        "required": False,
    },
    api_key={
        "type": ARG_TYPE.STR,
        "description": "API key for authentication in Qdrant Cloud.",
        "required": False,
    },
    prefix={
        "type": ARG_TYPE.STR,
        "description": "If set, the value is added to the REST URL path. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API",
        "required": False,
    },
    timeout={
        "type": ARG_TYPE.INT,
        "description": "Timeout for REST and gRPC API requests. Defaults to 5.0 seconds for REST and unlimited for gRPC",
        "required": False,
    },
    path={
        "type": ARG_TYPE.STR,
        "description": "Persistence path for a local Qdrant instance(:memory:).",
        "required": False,
    },
    collection_config={
        "type": ARG_TYPE.DICT,
        "description": "Collection creation configuration. See https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection",
        "required": True,
    },
)
# Sample connection arguments: `:memory:` as the location, plus the
# collection configuration (vector size and distance).
connection_args_example = dict(
    location=":memory:",
    collection_config=dict(size=386, distance="Cosine"),
)