Skip to content

Commit

Permalink
feat(driver): add a simple cache driver
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Oct 12, 2020
1 parent 51c5e45 commit 8a540a9
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 129 deletions.
112 changes: 20 additions & 92 deletions jina/drivers/cache.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,30 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"
from typing import Iterable, Any

import math
import os
import struct
from typing import Iterable
from .index import BaseIndexDriver

from . import BaseRecursiveDriver
from ..proto import uid

if False:
from ..proto import jina_pb2


class BloomFilterDriver(BaseRecursiveDriver):
""" Bloom filter to test whether a doc is observed or not based on its ``doc.id``.
It is used to speed up answers in a key-value storage system.
Values are stored on a disk which has slow access times. Bloom filter decisions are much faster.
"""

def __init__(self, bit_array: int = 0, num_hash: int = 4, *args, **kwargs):
"""
:param bit_array: a bit array of m bits, all set to 0.
:param num_hash: number of hash functions, can only be 4, 8.
:param args:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self._bit_array = bit_array
# unpack int64 (8 bytes) to eight uint8 (1 bytes)
# to simulate a group of hash functions in bloom filter
if num_hash == 2:
fmt = 'I'
elif num_hash == 4:
# 8 bytes/4 = 2 bytes = H (unsigned short)
fmt = 'H'
elif num_hash == 8:
fmt = 'B'
else:
raise ValueError(f'"num_hash" must be 4 or 8 but given {num_hash}')
fmt = fmt * num_hash
self._num_bit = 2 ** (64 / num_hash)
self._num_hash = num_hash
self._hash_funcs = lambda x: struct.unpack(fmt, uid.id2bytes(x))

def __contains__(self, doc_id: str):
for _r in self._hash_funcs(doc_id):
if not (self._bit_array & (1 << _r)):
return False
return True

def on_hit(self, doc: 'jina_pb2.Document'):
"""Function to call when doc exists"""
raise NotImplementedError

def on_miss(self, doc: 'jina_pb2.Document'):
"""Function to call when doc is missing"""
pass

def _add(self, doc_id: str):
for _r in self._hash_funcs(doc_id):
self._bit_array |= (1 << _r)

def _flush(self):
"""Write the bloom filter by writing ``_bit_array`` back"""
pass

@property
def false_positive_rate(self) -> float:
"""Returns the false positive rate with 10000 docs.
The more items added, the larger the probability of false positives.
"""
return math.pow(1 - math.exp(-(self._num_hash * 10000 / self._num_bit)), self._num_hash)
class BaseCacheDriver(BaseIndexDriver):

def _apply_all(self, docs: Iterable['jina_pb2.Document'], *args, **kwargs) -> None:
for doc in docs:
if doc.id in self:
self.on_hit(doc)
for d in docs:
result = self.exec[d.id]
if result is None:
self.on_miss(d)
else:
self._add(doc.id)
self.on_miss(doc)
self._flush()
self.on_hit(d, result)

def on_miss(self, doc: 'jina_pb2.Document') -> None:
"""Function to call when doc is missing, the default behavior is add to cache when miss
class EnvBloomFilterDriver(BloomFilterDriver):
"""
A :class:`BloomFilterDriver` that stores ``bit_array`` in OS environment.
Just an example how to share & persist ``bit_array``
"""
:param doc: the document in the request but missed in the cache
"""
self.exec_fn(doc.id, doc.SerializeToString())

def __init__(self, env_name: str = 'JINA_BLOOMFILTER_1', *args, **kwargs):
super().__init__(*args, **kwargs)
self._env_name = env_name
self._bit_array = int(os.environ.get(env_name, '0'), 2)
def on_hit(self, req_doc: 'jina_pb2.Document', hit_result: Any) -> None:
""" Function to call when doc is hit
def _flush(self):
os.environ[self._env_name] = bin(self._bit_array)[2:]
:param req_doc: the document in the request and hitted in the cache
:param hit_result: the hit result returned by the cache
:return:
"""
pass
4 changes: 3 additions & 1 deletion jina/drivers/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class BaseIndexDriver(BaseExecutableDriver):
"""Drivers inherited from this Driver will bind :meth:`craft` by default """
"""Drivers inherited from this Driver will bind :meth:`add` by default """

def __init__(self, executor: str = None, method: str = 'add', *args, **kwargs):
super().__init__(executor, method, *args, **kwargs)
Expand Down Expand Up @@ -42,3 +42,5 @@ def _apply_all(self, docs: Iterable['jina_pb2.Document'], *args, **kwargs) -> No
keys = [uid.id2hash(doc.id) for doc in docs]
values = [doc.SerializeToString() for doc in docs]
self.exec_fn(keys, values)


10 changes: 5 additions & 5 deletions jina/executors/indexers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
__license__ = "Apache-2.0"

import os
from typing import Tuple, Union, List, Iterator, Optional
from typing import Tuple, Union, List, Iterator, Optional, Any

import numpy as np

from .. import BaseExecutor
from ..compound import CompoundExecutor
from ...helper import call_obj_fn, cached_property, get_readable_size

if False:
from ...proto import jina_pb2


class BaseIndexer(BaseExecutor):
""" base class for storing and searching any kind of data structure
Expand Down Expand Up @@ -202,14 +199,17 @@ class BaseKVIndexer(BaseIndexer):
def add(self, keys: Iterator[int], values: Iterator[bytes], *args, **kwargs):
raise NotImplementedError

def query(self, key: int) -> Optional['jina_pb2.Document']:
def query(self, key: Any) -> Optional[Any]:
""" Find the protobuf chunk/doc using id
:param key: ``id``
:return: protobuf chunk or protobuf document
"""
raise NotImplementedError

def __getitem__(self, key: Any) -> Optional[Any]:
return self.query(key)


class CompoundIndexer(CompoundExecutor):
"""A Frequently used pattern for combining A :class:`BaseVectorIndexer` and :class:`BaseKVIndexer`.
Expand Down
46 changes: 46 additions & 0 deletions jina/executors/indexers/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional

import numpy as np

from . import BaseKVIndexer
from ...proto import uid


class InMemoryIDCache(BaseKVIndexer):
"""Store doc ids in a int64 set and persistent it to a numpy array """

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.handler_mutex = False #: for Cache we need to release the handler mutex to allow RW at the same time

def add(self, doc_id: str, *args, **kwargs):
d_id = uid.id2hash(doc_id)
self.query_handler.add(d_id)
self._size += 1
self.write_handler.write(np.int64(d_id).tobytes())

def query(self, doc_id: str, *args, **kwargs) -> Optional[bool]:
if self.query_handler:
d_id = uid.id2hash(doc_id)
return (d_id in self.query_handler) or None

@property
def is_exist(self) -> bool:
""" Always return true, delegate to :meth:`get_query_handler`
:return: True
"""
return True

def get_query_handler(self):
if super().is_exist:
with open(self.index_abspath, 'rb') as fp:
return set(np.frombuffer(fp.read(), dtype=np.int64))
else:
return set()

def get_add_handler(self):
return open(self.index_abspath, 'ab')

def get_create_handler(self):
return open(self.index_abspath, 'wb')
89 changes: 58 additions & 31 deletions tests/unit/drivers/test_cache_driver.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,68 @@
import os
from typing import Any

import numpy as np
import pytest

from jina.drivers.cache import BloomFilterDriver, EnvBloomFilterDriver
from tests import random_docs
import os
from jina.drivers.cache import BaseCacheDriver
from jina.executors.indexers.cache import InMemoryIDCache
from jina.proto import jina_pb2, uid
from tests import random_docs, rm_files

filename = 'test-tmp.bin'

@pytest.mark.parametrize('num_hash', [4, 8])
def test_cache_driver_twice(num_hash):
docs = list(random_docs(10))
driver = BloomFilterDriver(num_hash=num_hash)
driver._apply_all(docs)

with pytest.raises(NotImplementedError):
# duplicate docs
driver._apply_all(docs)
class MockCacheDriver(BaseCacheDriver):

# new docs
@property
def exec_fn(self):
return self._exec_fn

def on_hit(self, req_doc: 'jina_pb2.Document', hit_result: Any) -> None:
raise NotImplementedError


def test_cache_driver_twice():
docs = list(random_docs(10))
driver._apply_all(docs)
driver = MockCacheDriver()
with InMemoryIDCache(filename) as executor:
assert not executor.handler_mutex
driver.attach(executor=executor, pea=None)

driver._traverse_apply(docs)

with pytest.raises(NotImplementedError):
# duplicate docs
driver._traverse_apply(docs)

@pytest.mark.parametrize('num_hash', [4, 8])
def test_cache_driver_env(num_hash):
# new docs
docs = list(random_docs(10))
driver._traverse_apply(docs)

# check persistence
assert os.path.exists(filename)
rm_files([filename])


def test_cache_driver_from_file():
docs = list(random_docs(10))
driver = EnvBloomFilterDriver(num_hash=num_hash)
assert os.environ.get(driver._env_name, None) is None
driver._apply_all(docs)

with pytest.raises(NotImplementedError):
# duplicate docs
driver._apply_all(docs)

# now start a new one
# should fail again, as bloom filter is persisted in os.env
with pytest.raises(NotImplementedError):
driver = EnvBloomFilterDriver(num_hash=num_hash)
driver._apply_all(docs)

assert os.environ.get(driver._env_name, None) is not None
os.environ.pop(driver._env_name)
with open(filename, 'wb') as fp:
fp.write(np.array([uid.id2hash(d.id) for d in docs], dtype=np.int64).tobytes())

driver = MockCacheDriver()
with InMemoryIDCache(filename) as executor:
assert not executor.handler_mutex
driver.attach(executor=executor, pea=None)

with pytest.raises(NotImplementedError):
# duplicate docs
driver._traverse_apply(docs)

# new docs
docs = list(random_docs(10))
driver._traverse_apply(docs)

# check persistence
assert os.path.exists(filename)
rm_files([filename])

0 comments on commit 8a540a9

Please sign in to comment.