-
Notifications
You must be signed in to change notification settings - Fork 705
/
module.py
64 lines (54 loc) · 2.59 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import subprocess
import sys
import pgvector.psycopg
import psycopg
from ..base.module import BaseANN
class PGVector(BaseANN):
    """ann-benchmarks wrapper around a PostgreSQL table indexed with pgvector HNSW.

    Vectors are loaded into an ``items (id int, embedding vector(d))`` table,
    an HNSW index is built over ``embedding``, and nearest neighbours are
    fetched with ``ORDER BY embedding <op> %s LIMIT %s``.
    """

    # Single source of truth for supported metrics:
    # metric name -> (pgvector distance operator, HNSW operator class).
    _METRICS = {
        "angular": ("<=>", "vector_cosine_ops"),
        "euclidean": ("<->", "vector_l2_ops"),
    }

    # The index is named explicitly so get_memory_usage() measures exactly the
    # index built by fit(), instead of relying on PostgreSQL's auto-naming.
    _INDEX_NAME = "items_embedding_idx"

    def __init__(self, metric, method_param):
        """Configure the wrapper.

        Args:
            metric: ``"angular"`` (cosine operator ``<=>``) or ``"euclidean"``
                (L2 operator ``<->``).
            method_param: mapping with ``'M'`` (HNSW ``m``) and
                ``'efConstruction'`` (HNSW ``ef_construction``).

        Raises:
            KeyError: if ``method_param`` lacks ``'M'`` or ``'efConstruction'``.
            RuntimeError: if ``metric`` is not supported.
        """
        self._metric = metric
        self._m = method_param['M']
        self._ef_construction = method_param['efConstruction']
        # Initialised so __str__ works before set_query_arguments() is called.
        self._ef_search = None
        self._cur = None
        try:
            operator, self._opclass = self._METRICS[metric]
        except KeyError:
            raise RuntimeError(f"unknown metric {metric}") from None
        self._query = f"SELECT id FROM items ORDER BY embedding {operator} %s LIMIT %s"

    def fit(self, X):
        """Start PostgreSQL, (re)create the ``items`` table from ``X`` and build the HNSW index.

        Args:
            X: 2-D array-like exposing ``shape``; row ``i`` is stored with id ``i``.
        """
        subprocess.run(
            ["service", "postgresql", "start"],
            check=True,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        conn = psycopg.connect(user="ann", password="ann", dbname="ann", autocommit=True)
        # Register pgvector's adapters on this connection so numpy rows can be
        # written/read as the `vector` type.
        pgvector.psycopg.register_vector(conn)
        cur = conn.cursor()
        # Dropping the table also drops its index, so the explicit index name is free again.
        cur.execute("DROP TABLE IF EXISTS items")
        cur.execute("CREATE TABLE items (id int, embedding vector(%d))" % X.shape[1])
        # PLAIN storage disables compression and out-of-line (TOAST) storage for the column.
        cur.execute("ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN")
        print("copying data...")
        with cur.copy("COPY items (id, embedding) FROM STDIN WITH (FORMAT BINARY)") as copy:
            # Binary COPY needs the column types declared up front.
            copy.set_types(["int4", "vector"])
            for row_id, embedding in enumerate(X):
                copy.write_row((row_id, embedding))
        print("creating index...")
        # Only an internal constant, the opclass from _METRICS and numeric
        # parameters (%d) are interpolated here.
        cur.execute(
            "CREATE INDEX %s ON items USING hnsw (embedding %s) WITH (m = %d, ef_construction = %d)"
            % (self._INDEX_NAME, self._opclass, self._m, self._ef_construction)
        )
        print("done!")
        self._cur = cur

    def set_query_arguments(self, ef_search):
        """Set the HNSW ``ef_search`` session parameter used by subsequent queries.

        Requires fit() to have been called first (it creates the cursor).
        """
        self._ef_search = ef_search
        # SET does not accept bind parameters; %d also guarantees a numeric value.
        self._cur.execute("SET hnsw.ef_search = %d" % ef_search)

    def query(self, v, n):
        """Return the ids of the ``n`` nearest stored vectors to ``v``."""
        # binary=True requests binary results; prepare=True asks psycopg to use
        # a server-side prepared statement for this repeated query.
        self._cur.execute(self._query, (v, n), binary=True, prepare=True)
        return [row_id for (row_id,) in self._cur.fetchall()]

    def get_memory_usage(self):
        """Return the on-disk size of the HNSW index in KiB, or 0 before fit()."""
        if self._cur is None:
            return 0
        self._cur.execute(f"SELECT pg_relation_size('{self._INDEX_NAME}')")
        return self._cur.fetchone()[0] / 1024

    def __str__(self):
        return f"PGVector(m={self._m}, ef_construction={self._ef_construction}, ef_search={self._ef_search})"