-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(driver): add a simple cache driver
- Loading branch information
Showing
5 changed files
with
132 additions
and
129 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,102 +1,30 @@ | ||
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved." | ||
__license__ = "Apache-2.0" | ||
from typing import Iterable, Any | ||
|
||
import math | ||
import os | ||
import struct | ||
from typing import Iterable | ||
from .index import BaseIndexDriver | ||
|
||
from . import BaseRecursiveDriver | ||
from ..proto import uid | ||
|
||
if False: | ||
from ..proto import jina_pb2 | ||
|
||
|
||
class BloomFilterDriver(BaseRecursiveDriver): | ||
""" Bloom filter to test whether a doc is observed or not based on its ``doc.id``. | ||
It is used to speed up answers in a key-value storage system. | ||
Values are stored on a disk which has slow access times. Bloom filter decisions are much faster. | ||
""" | ||
|
||
def __init__(self, bit_array: int = 0, num_hash: int = 4, *args, **kwargs): | ||
""" | ||
:param bit_array: a bit array of m bits, all set to 0. | ||
:param num_hash: number of hash functions, can only be 2, 4 or 8. | ||
:param args: | ||
:param kwargs: | ||
""" | ||
super().__init__(*args, **kwargs) | ||
self._bit_array = bit_array | ||
# unpack int64 (8 bytes) to eight uint8 (1 bytes) | ||
# to simulate a group of hash functions in bloom filter | ||
if num_hash == 2: | ||
fmt = 'I' | ||
elif num_hash == 4: | ||
# 8 bytes/4 = 2 bytes = H (unsigned short) | ||
fmt = 'H' | ||
elif num_hash == 8: | ||
fmt = 'B' | ||
else: | ||
raise ValueError(f'"num_hash" must be 4 or 8 but given {num_hash}') | ||
fmt = fmt * num_hash | ||
self._num_bit = 2 ** (64 / num_hash) | ||
self._num_hash = num_hash | ||
self._hash_funcs = lambda x: struct.unpack(fmt, uid.id2bytes(x)) | ||
|
||
def __contains__(self, doc_id: str): | ||
for _r in self._hash_funcs(doc_id): | ||
if not (self._bit_array & (1 << _r)): | ||
return False | ||
return True | ||
|
||
def on_hit(self, doc: 'jina_pb2.Document'): | ||
"""Function to call when doc exists""" | ||
raise NotImplementedError | ||
|
||
def on_miss(self, doc: 'jina_pb2.Document'): | ||
"""Function to call when doc is missing""" | ||
pass | ||
|
||
def _add(self, doc_id: str): | ||
for _r in self._hash_funcs(doc_id): | ||
self._bit_array |= (1 << _r) | ||
|
||
def _flush(self): | ||
"""Write the bloom filter by writing ``_bit_array`` back""" | ||
pass | ||
|
||
@property | ||
def false_positive_rate(self) -> float: | ||
"""Returns the false positive rate with 10000 docs. | ||
The more items added, the larger the probability of false positives. | ||
""" | ||
return math.pow(1 - math.exp(-(self._num_hash * 10000 / self._num_bit)), self._num_hash) | ||
class BaseCacheDriver(BaseIndexDriver): | ||
|
||
def _apply_all(self, docs: Iterable['jina_pb2.Document'], *args, **kwargs) -> None: | ||
for doc in docs: | ||
if doc.id in self: | ||
self.on_hit(doc) | ||
for d in docs: | ||
result = self.exec[d.id] | ||
if result is None: | ||
self.on_miss(d) | ||
else: | ||
self._add(doc.id) | ||
self.on_miss(doc) | ||
self._flush() | ||
self.on_hit(d, result) | ||
|
||
def on_miss(self, doc: 'jina_pb2.Document') -> None: | ||
"""Function to call when doc is missing; the default behavior is to add it to the cache on a miss | ||
class EnvBloomFilterDriver(BloomFilterDriver): | ||
""" | ||
A :class:`BloomFilterDriver` that stores ``bit_array`` in OS environment. | ||
Just an example of how to share & persist ``bit_array`` | ||
""" | ||
:param doc: the document in the request but missed in the cache | ||
""" | ||
self.exec_fn(doc.id, doc.SerializeToString()) | ||
|
||
def __init__(self, env_name: str = 'JINA_BLOOMFILTER_1', *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self._env_name = env_name | ||
self._bit_array = int(os.environ.get(env_name, '0'), 2) | ||
def on_hit(self, req_doc: 'jina_pb2.Document', hit_result: Any) -> None: | ||
""" Function to call when doc is hit | ||
def _flush(self): | ||
os.environ[self._env_name] = bin(self._bit_array)[2:] | ||
:param req_doc: the document in the request that was hit in the cache | ||
:param hit_result: the hit result returned by the cache | ||
:return: | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Optional | ||
|
||
import numpy as np | ||
|
||
from . import BaseKVIndexer | ||
from ...proto import uid | ||
|
||
|
||
class InMemoryIDCache(BaseKVIndexer): | ||
"""Stores doc ids in an int64 set and persists it to a numpy array """ | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.handler_mutex = False #: for Cache we need to release the handler mutex to allow RW at the same time | ||
|
||
def add(self, doc_id: str, *args, **kwargs): | ||
d_id = uid.id2hash(doc_id) | ||
self.query_handler.add(d_id) | ||
self._size += 1 | ||
self.write_handler.write(np.int64(d_id).tobytes()) | ||
|
||
def query(self, doc_id: str, *args, **kwargs) -> Optional[bool]: | ||
if self.query_handler: | ||
d_id = uid.id2hash(doc_id) | ||
return (d_id in self.query_handler) or None | ||
|
||
@property | ||
def is_exist(self) -> bool: | ||
""" Always return true, delegate to :meth:`get_query_handler` | ||
:return: True | ||
""" | ||
return True | ||
|
||
def get_query_handler(self): | ||
if super().is_exist: | ||
with open(self.index_abspath, 'rb') as fp: | ||
return set(np.frombuffer(fp.read(), dtype=np.int64)) | ||
else: | ||
return set() | ||
|
||
def get_add_handler(self): | ||
return open(self.index_abspath, 'ab') | ||
|
||
def get_create_handler(self): | ||
return open(self.index_abspath, 'wb') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,68 @@ | ||
import os | ||
from typing import Any | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from jina.drivers.cache import BloomFilterDriver, EnvBloomFilterDriver | ||
from tests import random_docs | ||
import os | ||
from jina.drivers.cache import BaseCacheDriver | ||
from jina.executors.indexers.cache import InMemoryIDCache | ||
from jina.proto import jina_pb2, uid | ||
from tests import random_docs, rm_files | ||
|
||
filename = 'test-tmp.bin' | ||
|
||
@pytest.mark.parametrize('num_hash', [4, 8]) | ||
def test_cache_driver_twice(num_hash): | ||
docs = list(random_docs(10)) | ||
driver = BloomFilterDriver(num_hash=num_hash) | ||
driver._apply_all(docs) | ||
|
||
with pytest.raises(NotImplementedError): | ||
# duplicate docs | ||
driver._apply_all(docs) | ||
class MockCacheDriver(BaseCacheDriver): | ||
|
||
# new docs | ||
@property | ||
def exec_fn(self): | ||
return self._exec_fn | ||
|
||
def on_hit(self, req_doc: 'jina_pb2.Document', hit_result: Any) -> None: | ||
raise NotImplementedError | ||
|
||
|
||
def test_cache_driver_twice(): | ||
docs = list(random_docs(10)) | ||
driver._apply_all(docs) | ||
driver = MockCacheDriver() | ||
with InMemoryIDCache(filename) as executor: | ||
assert not executor.handler_mutex | ||
driver.attach(executor=executor, pea=None) | ||
|
||
driver._traverse_apply(docs) | ||
|
||
with pytest.raises(NotImplementedError): | ||
# duplicate docs | ||
driver._traverse_apply(docs) | ||
|
||
@pytest.mark.parametrize('num_hash', [4, 8]) | ||
def test_cache_driver_env(num_hash): | ||
# new docs | ||
docs = list(random_docs(10)) | ||
driver._traverse_apply(docs) | ||
|
||
# check persistence | ||
assert os.path.exists(filename) | ||
rm_files([filename]) | ||
|
||
|
||
def test_cache_driver_from_file(): | ||
docs = list(random_docs(10)) | ||
driver = EnvBloomFilterDriver(num_hash=num_hash) | ||
assert os.environ.get(driver._env_name, None) is None | ||
driver._apply_all(docs) | ||
|
||
with pytest.raises(NotImplementedError): | ||
# duplicate docs | ||
driver._apply_all(docs) | ||
|
||
# now start a new one | ||
# should fail again, as the bloom filter is persisted in os.environ | ||
with pytest.raises(NotImplementedError): | ||
driver = EnvBloomFilterDriver(num_hash=num_hash) | ||
driver._apply_all(docs) | ||
|
||
assert os.environ.get(driver._env_name, None) is not None | ||
os.environ.pop(driver._env_name) | ||
with open(filename, 'wb') as fp: | ||
fp.write(np.array([uid.id2hash(d.id) for d in docs], dtype=np.int64).tobytes()) | ||
|
||
driver = MockCacheDriver() | ||
with InMemoryIDCache(filename) as executor: | ||
assert not executor.handler_mutex | ||
driver.attach(executor=executor, pea=None) | ||
|
||
with pytest.raises(NotImplementedError): | ||
# duplicate docs | ||
driver._traverse_apply(docs) | ||
|
||
# new docs | ||
docs = list(random_docs(10)) | ||
driver._traverse_apply(docs) | ||
|
||
# check persistence | ||
assert os.path.exists(filename) | ||
rm_files([filename]) | ||
|