-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathhello.py
executable file
·97 lines (78 loc) · 2.57 KB
/
hello.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python
import os
from urllib.parse import urljoin
from httpx import Client, Response
def log_request(request):
    """Print a one-line trace of an outgoing HTTP request.

    Intended as an httpx ``request`` event hook; the printed line has the
    form ``> METHOD URL``.
    """
    method, url = request.method, request.url
    print(f"> {method} {url}")
def log_response(response):
    """Print a one-line trace of a received HTTP response.

    Intended as an httpx ``response`` event hook. Prints
    ``< METHOD URL - STATUS``; for any non-2xx status it also prints the
    response body, to make failures easier to diagnose.
    """
    request = response.request
    print(f"< {request.method} {request.url} - {response.status_code}")
    # 2xx (200-299) is success; only dump the body for anything >= 300.
    # (Previously ">= 299", which wrongly treated status 299 as an error.)
    if response.status_code >= 300:
        # Inside a response hook the body has not been read yet, so it must
        # be loaded explicitly before ``.text`` is available.
        response.read()
        print(f"\n{response.text}")
def main():
    """Run a small end-to-end exercise against a Chroma v1 REST API.

    The server base URL comes from the ``ENDPOINT`` environment variable
    (default ``http://localhost:8000``). The script then:

    1. prints the server version,
    2. reuses the collection ``my-collection`` or creates it,
    3. adds two sample vectors with metadata,
    4. runs a query (one query embedding, ``n_results=1``) and prints it,
    5. deletes the collection.

    Every request/response is traced by ``log_request``/``log_response``.
    An error status raises from the ``Response.raise_for_status`` hook.
    """
    # HTTP setup
    endpoint = os.getenv("ENDPOINT", "http://localhost:8000")
    # The leading "/" makes urljoin replace any path already on ENDPOINT.
    api_url = urljoin(endpoint, "/api/v1/")
    headers = {
        "Accept": "application/json; charset=utf-8",
        "Content-Type": "application/json; charset=utf-8",
    }
    # Context manager so the connection pool is closed even when a request
    # fails (the raise_for_status hook raises out of the call).
    with Client(
        headers=headers,
        event_hooks={
            "request": [log_request],
            "response": [log_response, Response.raise_for_status],
        },
    ) as client:
        version = client.get(urljoin(api_url, "version")).json()
        print(f"Chroma {version}")

        # Reuse the collection if it already exists, otherwise create it.
        collection_name = "my-collection"
        collections = client.get(urljoin(api_url, "collections")).json()
        collection = next(
            (c for c in collections if c["name"] == collection_name), None
        )
        if collection is None:
            collection = client.post(
                urljoin(api_url, "collections"),
                json={"name": collection_name},
            ).json()

        # index data
        vectors = [
            {
                "id": "d8f940f1-d6c1-4d8e-82c1-488eb7801e57",
                "values": [0.1, 0.2, 0.3],
                "metadata": {"genre": "drama"},
            },
            {
                "id": "c47eade8-59b9-4c49-9172-a0ce3d9dd0af",
                "values": [0.2, 0.3, 0.4],
                "metadata": {"genre": "action"},
            },
        ]
        # Transpose the row-wise records into the parallel (column-wise)
        # lists sent to the add endpoint.
        data = {
            "ids": [v["id"] for v in vectors],
            "embeddings": [v["values"] for v in vectors],
            "metadatas": [v["metadata"] for v in vectors],
        }
        client.post(
            urljoin(api_url, f"collections/{collection['id']}/add"), json=data
        )

        # search
        query = {
            "query_embeddings": [[0.15, 0.12, 1.23]],
            "n_results": 1,
            "include": ["embeddings", "metadatas"],
        }
        results = client.post(
            urljoin(api_url, f"collections/{collection['id']}/query"),
            json=query,
        ).json()
        print(results)

        # Delete is addressed by collection *name*, unlike add/query above,
        # which use the collection id.
        client.delete(urljoin(api_url, f"collections/{collection_name}"))


if __name__ == "__main__":
    main()