Skip to content

Commit

Permalink
feat(driver): add bloom filter driver
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Oct 11, 2020
1 parent c2cad81 commit 51c5e45
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
23 changes: 18 additions & 5 deletions jina/drivers/cache.py
@@ -1,6 +1,7 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

import math
import os
import struct
from typing import Iterable
Expand All @@ -18,26 +19,30 @@ class BloomFilterDriver(BaseRecursiveDriver):
Values are stored on a disk which has slow access times. Bloom filter decisions are much faster.
"""

def __init__(self, bit_array: int = 0, num_hash: int = 8, *args, **kwargs):
def __init__(self, bit_array: int = 0, num_hash: int = 4, *args, **kwargs):
"""
:param bit_array: a bit array of m bits, all set to 0.
:param num_hash: number of hash functions, can only be 4, 8.
larger value, slower, but more memory efficient
:param args:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self._bit_array = bit_array
# unpack int64 (8 bytes) to eight uint8 (1 bytes)
# to simulate a group of hash functions in bloom filter
if num_hash == 4:
if num_hash == 2:
fmt = 'I'
elif num_hash == 4:
# 8 bytes/4 = 2 bytes = H (unsigned short)
fmt = 'H' * 4
fmt = 'H'
elif num_hash == 8:
fmt = 'B' * 8
fmt = 'B'
else:
raise ValueError(f'"num_hash" must be 4 or 8 but given {num_hash}')
fmt = fmt * num_hash
self._num_bit = 2 ** (64 / num_hash)
self._num_hash = num_hash
self._hash_funcs = lambda x: struct.unpack(fmt, uid.id2bytes(x))

def __contains__(self, doc_id: str):
Expand All @@ -62,6 +67,14 @@ def _flush(self):
"""Write the bloom filter by writing ``_bit_array`` back"""
pass

@property
def false_positive_rate(self) -> float:
"""Returns the false positive rate with 10000 docs.
The more items added, the larger the probability of false positives.
"""
return math.pow(1 - math.exp(-(self._num_hash * 10000 / self._num_bit)), self._num_hash)

def _apply_all(self, docs: Iterable['jina_pb2.Document'], *args, **kwargs) -> None:
for doc in docs:
if doc.id in self:
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/drivers/test_cache_driver.py
@@ -1,9 +1,8 @@
import os

import pytest

from jina.drivers.cache import BloomFilterDriver, EnvBloomFilterDriver
from tests import random_docs
import os


@pytest.mark.parametrize('num_hash', [4, 8])
Expand Down Expand Up @@ -39,4 +38,4 @@ def test_cache_driver_env(num_hash):
driver._apply_all(docs)

assert os.environ.get(driver._env_name, None) is not None
print(os.environ.pop(driver._env_name))
os.environ.pop(driver._env_name)

0 comments on commit 51c5e45

Please sign in to comment.