chore: blackify all
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
JoanFM committed Jul 20, 2023
1 parent ddb47bc commit 3469de8
Showing 10 changed files with 104 additions and 106 deletions.
4 changes: 1 addition & 3 deletions annlite/container.py
@@ -188,9 +188,7 @@ def filter_cells(

# reordering the results from multiple cells
if order_by and len(cells) > 1:
result = sorted(
result, key=lambda d: d['order_by'], reverse=not ascending
)
result = sorted(result, key=lambda d: d['order_by'], reverse=not ascending)
if limit > 0:
result = result[:limit]

135 changes: 68 additions & 67 deletions annlite/index.py
@@ -53,23 +53,23 @@ class AnnLite(CellContainer):
"""

def __init__(
self,
n_dim: int,
metric: Union[str, Metric] = 'cosine',
n_cells: int = 1,
n_subvectors: Optional[int] = None,
n_clusters: Optional[int] = 256,
n_probe: int = 16,
n_components: Optional[int] = None,
initial_size: Optional[int] = None,
expand_step_size: int = 10240,
columns: Optional[Union[Dict, List]] = None,
filterable_attrs: Optional[Dict] = None,
data_path: Union[Path, str] = Path('./data'),
create_if_missing: bool = True,
read_only: bool = False,
verbose: bool = False,
**kwargs,
self,
n_dim: int,
metric: Union[str, Metric] = 'cosine',
n_cells: int = 1,
n_subvectors: Optional[int] = None,
n_clusters: Optional[int] = 256,
n_probe: int = 16,
n_components: Optional[int] = None,
initial_size: Optional[int] = None,
expand_step_size: int = 10240,
columns: Optional[Union[Dict, List]] = None,
filterable_attrs: Optional[Dict] = None,
data_path: Union[Path, str] = Path('./data'),
create_if_missing: bool = True,
read_only: bool = False,
verbose: bool = False,
**kwargs,
):
setup_logging(verbose)

@@ -81,7 +81,7 @@ def __init__(

if n_subvectors:
assert (
n_dim % n_subvectors == 0
n_dim % n_subvectors == 0
), '"n_dim" needs to be divisible by "n_subvectors"'
self.n_dim = n_dim
self.n_components = n_components
@@ -186,7 +186,7 @@ def __init__(
def _sanity_check(self, x: 'np.ndarray'):
assert x.ndim == 2, 'inputs must be a 2D array'
assert (
x.shape[1] == self.n_dim
x.shape[1] == self.n_dim
), f'inputs must have the same dimension as the index , got {x.shape[1]}, expected {self.n_dim}'

return x.shape
@@ -230,7 +230,7 @@ def train(self, x: 'np.ndarray', auto_save: bool = True, force_train: bool = Fal
self.dump_model()

def partial_train(
self, x: np.ndarray, auto_save: bool = True, force_train: bool = False
self, x: np.ndarray, auto_save: bool = True, force_train: bool = False
):
"""Partially train the index with the given data.
@@ -293,11 +293,11 @@ def index(self, docs: 'List', **kwargs):
return super(AnnLite, self).insert(x, assigned_cells, docs)

def update(
self,
docs: 'List',
raise_errors_on_not_found: bool = False,
insert_if_not_found: bool = True,
**kwargs,
self,
docs: 'List',
raise_errors_on_not_found: bool = False,
insert_if_not_found: bool = True,
**kwargs,
):
"""Update the documents in the index.
@@ -331,12 +331,12 @@ def update(
)

def search(
self,
docs: 'List',
filter: Optional[dict] = None,
limit: int = 10,
include_metadata: bool = True,
**kwargs,
self,
docs: 'List',
filter: Optional[dict] = None,
limit: int = 10,
include_metadata: bool = True,
**kwargs,
):
"""Search the index, and attach matches to the query Documents in `docs`
@@ -356,11 +356,11 @@ def search(
return match_docs

def search_by_vectors(
self,
query_np: 'np.ndarray',
filter: Optional[dict] = None,
limit: int = 10,
include_metadata: bool = True,
self,
query_np: 'np.ndarray',
filter: Optional[dict] = None,
limit: int = 10,
include_metadata: bool = True,
):
"""Search the index by vectors, and return the matches.
@@ -384,13 +384,13 @@ def search_by_vectors(
return match_dists, match_docs

def filter(
self,
filter: Dict,
limit: int = 10,
offset: int = 0,
order_by: Optional[str] = None,
ascending: bool = True,
include_metadata: bool = True,
self,
filter: Dict,
limit: int = 10,
offset: int = 0,
order_by: Optional[str] = None,
ascending: bool = True,
include_metadata: bool = True,
):
"""Find the documents by the filter.
@@ -427,12 +427,12 @@ def get_doc_by_id(self, doc_id: str):
return self._get_doc_by_id(doc_id)

def get_docs(
self,
filter: Optional[dict] = None,
limit: int = 10,
offset: int = 0,
order_by: Optional[str] = None,
ascending: bool = True,
self,
filter: Optional[dict] = None,
limit: int = 10,
offset: int = 0,
order_by: Optional[str] = None,
ascending: bool = True,
):
"""Get the documents.
@@ -480,11 +480,11 @@ def _cell_selection(self, query_np, limit):
return cells

def search_numpy(
self,
query_np: 'np.ndarray',
filter: Dict = {},
limit: int = 10,
**kwargs,
self,
query_np: 'np.ndarray',
filter: Dict = {},
limit: int = 10,
**kwargs,
):
"""Search the index and return distances to the query and ids of the closest documents.
@@ -519,9 +519,9 @@ def _search_numpy(self, query_np: 'np.ndarray', filter: Dict = {}, limit: int =
return dists, ids

def delete(
self,
docs: Union['List[Dict]', List[str]],
raise_errors_on_not_found: bool = False,
self,
docs: Union['List[Dict]', List[str]],
raise_errors_on_not_found: bool = False,
):
"""Delete entries from the index by id
@@ -531,7 +531,8 @@ def delete(

if len(docs) > 0:
super().delete(
docs if isinstance(docs[0], str) else [doc['id'] for doc in docs], raise_errors_on_not_found
docs if isinstance(docs[0], str) else [doc['id'] for doc in docs],
raise_errors_on_not_found,
)

def clear(self):
@@ -617,9 +618,9 @@ def index_hash(self):
def index_path(self):
if self.index_hash:
return (
self.data_path
/ f'snapshot-{self.params_hash}'
/ f'{self.index_hash}-SNAPSHOT'
self.data_path
/ f'snapshot-{self.params_hash}'
/ f'{self.index_hash}-SNAPSHOT'
)
return None

@@ -863,14 +864,14 @@ def _rebuild_index_from_remote(self, source_name: str, token: str):
# default has only one cell
shutil.unpack_archive(zip_file, self.data_path / f'cell_{cell_id}')
for f in list(
(
self.data_path
/ f'cell_{cell_id}'
/ zip_file.name.split('.zip')[0]
).iterdir()
(
self.data_path
/ f'cell_{cell_id}'
/ zip_file.name.split('.zip')[0]
).iterdir()
):
origin_database_path = (
self.data_path / f'cell_{cell_id}' / f.name
self.data_path / f'cell_{cell_id}' / f.name
)
if origin_database_path.exists():
origin_database_path.unlink()
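For orientation, here is a minimal usage sketch of the API whose signatures are reformatted above. The constructor, index, and search signatures are taken from the diff; the import path, the data_path value, and the dict shape of the documents (mirroring tests/conftest.py below) are assumptions, so treat this as an illustrative sketch rather than part of the commit.

import numpy as np

from annlite import AnnLite  # assumed import path

# Constructor arguments as shown in annlite/index.py above.
index = AnnLite(n_dim=4, metric='cosine', data_path='./example_data')

# Documents are plain dicts with an 'id' and an 'embedding', as in tests/conftest.py.
docs = [
    dict(id='doc1', embedding=np.array([1, 0, 0, 0], dtype=np.float32)),
    dict(id='doc2', embedding=np.array([0, 1, 0, 0], dtype=np.float32)),
]
index.index(docs)

# search() attaches matches to the query docs and returns them (limit defaults to 10).
queries = [dict(id='q1', embedding=np.array([1, 0, 0, 0], dtype=np.float32))]
match_docs = index.search(queries, limit=1)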
12 changes: 6 additions & 6 deletions annlite/storage/kv.py
@@ -52,8 +52,8 @@ def insert(self, docs: 'List'):
write_opt.sync = True
batch_size = 0
for doc in docs:
#TODO: How to serialize a dict
#write_batch.put(doc.id.encode(), doc.to_bytes(**self._serialize_config))
# TODO: How to serialize a dict
# write_batch.put(doc.id.encode(), doc.to_bytes(**self._serialize_config))
write_batch.put(doc['id'].encode(), pickle.dumps(doc))
batch_size += 1
self._db.write(write_batch, write_opt=write_opt)
@@ -68,7 +68,7 @@ def update(self, docs: 'List'):
if key not in self._db:
raise ValueError(f'The Doc ({doc["id"]}) does not exist in database!')

#write_batch.put(key, doc.to_bytes(**self._serialize_config))
# write_batch.put(key, doc.to_bytes(**self._serialize_config))
# TODO: Serialize
write_batch.put(key, pickle.dumps(doc))
self._db.write(write_batch, write_opt=write_opt)
@@ -89,7 +89,7 @@ def get(self, doc_ids: Union[str, list]) -> List:

for doc_bytes in self._db[[k.encode() for k in doc_ids]]:
if doc_bytes:
#docs.append(Document.from_bytes(doc_bytes, **self._serialize_config))
# docs.append(Document.from_bytes(doc_bytes, **self._serialize_config))
# TODO: Deserialize
docs.append(pickle.loads(doc_bytes))

@@ -144,8 +144,8 @@ def batched_iterator(self, batch_size: int = 1, **kwargs) -> 'List':
read_opt = ReadOptions()

for value in self._db.values(read_opt=read_opt):
#doc = Document.from_bytes(value, **self._serialize_config)
#TODO: Deserialize
# doc = Document.from_bytes(value, **self._serialize_config)
# TODO: Deserialize
docs.append(pickle.loads(value))
count += 1

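The TODO comments above are about finding a proper dict serializer; in the meantime the key-value store round-trips each doc dict through pickle, keyed by doc['id'].encode(). A standalone sketch of that round trip (the example doc is hypothetical, and the RocksDB write batch is omitted):

import pickle

doc = {'id': 'doc1', 'embedding': [1, 0, 0, 0]}

key = doc['id'].encode()   # what insert() uses as the key
value = pickle.dumps(doc)  # what insert() stores as the value

assert pickle.loads(value) == doc  # what get() recovers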
5 changes: 1 addition & 4 deletions annlite/storage/table.py
@@ -231,10 +231,7 @@ def insert(
for doc in docs:
doc_value = tuple(
[doc['id']]
+ [
_converting(doc[c]) if c in doc else None
for c in self.columns[2:]
]
+ [_converting(doc[c]) if c in doc else None for c in self.columns[2:]]
)
values.append(doc_value)
docs_size += 1
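The comprehension reformatted above builds one row tuple per doc for the storage table: the id first, then each remaining column's value converted if the doc has it, otherwise None. A standalone sketch of that pattern, with a hypothetical column layout and an identity stand-in for _converting:

def _converting(value):
    # Hypothetical stand-in; the real helper converts values before they are written to the table.
    return value

columns = ['_id', '_deleted', 'x', 'y']  # hypothetical column layout; the first two are skipped
doc = {'id': 'doc1', 'x': 0.3}           # 'y' is absent, so its slot becomes None

doc_value = tuple(
    [doc['id']] + [_converting(doc[c]) if c in doc else None for c in columns[2:]]
)
assert doc_value == ('doc1', 0.3, None)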
4 changes: 1 addition & 3 deletions scripts/get-all-test-paths.sh
@@ -5,8 +5,6 @@ set -ex
BATCH_SIZE=5

declare -a array1=( "tests/test_*.py" )
declare -a array2=( "tests/docarray/test_*.py" )
declare -a array3=( "tests/executor/test_*.py" )
dest=( "${array1[@]}" "${array2[@]}" "${array3[@]}" )
dest=( "${array1[@]}" )

printf '%s\n' "${dest[@]}" | jq -R . | jq -cs .
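For reference: `printf '%s\n'` prints each element of `dest` on its own line, `jq -R .` quotes each line as a JSON string, and `jq -cs .` slurps them into a compact JSON array, so with only `array1` left the script should emit ["tests/test_*.py"].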
18 changes: 9 additions & 9 deletions tests/conftest.py
@@ -7,20 +7,20 @@
@pytest.fixture(scope='session')
def docs():
return [
dict(id='doc1', embedding=np.array([1, 0, 0, 0])),
dict(id='doc2', embedding=np.array([0, 1, 0, 0])),
dict(id='doc3', embedding=np.array([0, 0, 1, 0])),
dict(id='doc4', embedding=np.array([0, 0, 0, 1])),
dict(id='doc5', embedding=np.array([1, 0, 1, 0])),
dict(id='doc6', embedding=np.array([0, 1, 0, 1])),
]
dict(id='doc1', embedding=np.array([1, 0, 0, 0])),
dict(id='doc2', embedding=np.array([0, 1, 0, 0])),
dict(id='doc3', embedding=np.array([0, 0, 1, 0])),
dict(id='doc4', embedding=np.array([0, 0, 0, 1])),
dict(id='doc5', embedding=np.array([1, 0, 1, 0])),
dict(id='doc6', embedding=np.array([0, 1, 0, 1])),
]


@pytest.fixture(scope='session')
def update_docs():
return [
dict(id='doc1', embedding=np.array([0, 0, 0, 1])),
]
dict(id='doc1', embedding=np.array([0, 0, 0, 1])),
]


@pytest.fixture(autouse=True)
11 changes: 7 additions & 4 deletions tests/test_crud.py
@@ -42,7 +42,9 @@ def test_update_legal(annlite_with_data):
index = annlite_with_data

updated_X = np.random.random((Nt, D)).astype(np.float32)
updated_docs = [dict(id=f'{i}', embedding=updated_X[i], x=random.random()) for i in range(Nt)]
updated_docs = [
dict(id=f'{i}', embedding=updated_X[i], x=random.random()) for i in range(Nt)
]

index.update(updated_docs)
matches = index.search(updated_docs)
@@ -56,9 +58,10 @@ def test_update_illegal(annlite_with_data):
index = annlite_with_data

updated_X = np.random.random((Nt, D)).astype(np.float32)
updated_docs = [dict(
id=f'{i}_wrong', embedding=updated_X[i], x=random.random()
) for i in range(Nt)]
updated_docs = [
dict(id=f'{i}_wrong', embedding=updated_X[i], x=random.random())
for i in range(Nt)
]

with pytest.raises(Exception):
index.update(
5 changes: 4 additions & 1 deletion tests/test_filter.py
@@ -169,7 +169,10 @@ def test_filter_with_limit_offset(tmpfile, limit, offset, order_by, ascending):
)
X = np.random.random((N, D)).astype(np.float32)

docs = [dict(id=f'{i}', embedding=X[i], x=random.random(), y=random.random()) for i in range(N)]
docs = [
dict(id=f'{i}', embedding=X[i], x=random.random(), y=random.random())
for i in range(N)
]
index.index(docs)

matches = index.filter(