Add *much better* serialization with tofile/fromfile.
This utilizes the underlying bitarray's tofile/fromfile
as well.

Add MANIFEST.in so that ez_setup.py is included in the
sdist.
mariusae committed Oct 22, 2009
1 parent e9f6038 commit 764f18f
Showing 3 changed files with 118 additions and 6 deletions.
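In practice the new API round-trips like this (a minimal sketch, assuming pybloom and the bitarray package are installed; note, as the new test below does, that bitarray insists on a real file object rather than StringIO):

    import tempfile
    from pybloom import BloomFilter

    bf = BloomFilter(capacity=1000, error_rate=0.001)
    bf.add('some-key')

    f = tempfile.TemporaryFile()
    bf.tofile(f)

    f.seek(0)
    restored = BloomFilter.fromfile(f)
    assert 'some-key' in restored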
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
include ez_setup.py
91 changes: 86 additions & 5 deletions pybloom/pybloom.py
@@ -34,7 +34,7 @@
"""
import math
import hashlib
from struct import unpack, pack, calcsize

try:
    import bitarray
@@ -83,6 +83,8 @@ def _make_hashfuncs(key):


class BloomFilter(object):
    FILE_FMT = '<dQQQQ'

    def __init__(self, capacity, error_rate=0.001):
        """Implements a space-efficient probabilistic data structure
@@ -117,15 +119,18 @@ def __init__(self, capacity, error_rate=0.001):
        bits_per_slice = int(math.ceil(
            (2 * capacity * abs(math.log(error_rate))) /
            (num_slices * (math.log(2) ** 2))))
        self._setup(error_rate, num_slices, bits_per_slice, capacity, 0)
        self.bitarray = bitarray.bitarray(self.num_bits, endian='little')
        self.bitarray.setall(False)
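        # Worked example of the sizing formula above (illustrative only;
        # num_slices is computed earlier in __init__, outside this excerpt):
        # capacity=1000, error_rate=0.001 and num_slices=10 give
        #   bits_per_slice = ceil(2*1000*6.9078 / (10 * 0.4805)) = 2876
        # so num_bits = 10 * 2876 = 28760 bits (~3.5 KB).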

    def _setup(self, error_rate, num_slices, bits_per_slice, capacity, count):
        self.error_rate = error_rate
        self.num_slices = num_slices
        self.bits_per_slice = bits_per_slice
        self.capacity = capacity
        self.num_bits = num_slices * bits_per_slice
        # Honor the count passed in; fromfile restores it from the header.
        self.count = count
        self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice)

    def __contains__(self, key):
@@ -180,6 +185,37 @@ def add(self, key, skip_check=False):
        self.count += 1
        return False

    def tofile(self, f):
        """Write the bloom filter to file object `f'. Underlying bits
        are written as machine values. This is much more space
        efficient than pickling the object."""
        f.write(pack(self.FILE_FMT, self.error_rate, self.num_slices,
                     self.bits_per_slice, self.capacity, self.count))
        self.bitarray.tofile(f)

    @classmethod
    def fromfile(cls, f, n=-1):
        """Read a bloom filter from file-object `f' serialized with
        ``BloomFilter.tofile''. If `n' > 0, read only so many bytes."""
        headerlen = calcsize(cls.FILE_FMT)

        if 0 < n < headerlen:
            raise ValueError('n too small!')

        filter = cls(1)  # Bogus instantiation; `_setup' overwrites it.
        filter._setup(*unpack(cls.FILE_FMT, f.read(headerlen)))
        filter.bitarray = bitarray.bitarray(endian='little')
        if n > 0:
            filter.bitarray.fromfile(f, n - headerlen)
        else:
            filter.bitarray.fromfile(f)
        # bitarray.tofile() pads the last byte, so the bit count read
        # back may exceed num_bits by up to 7 bits.
        if filter.num_bits != filter.bitarray.length() and \
                (filter.num_bits + (8 - filter.num_bits % 8)
                 != filter.bitarray.length()):
            raise ValueError('Bit length mismatch!')

        return filter
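For reference, a sketch of the on-disk layout fromfile expects: a fixed 40-byte header packed per FILE_FMT = '<dQQQQ' (a little-endian float64 error_rate followed by four uint64s: num_slices, bits_per_slice, capacity, count), then the raw bitarray payload padded to a whole byte. The field values here are made-up illustrations, not derived from the sizing formulas:

    from struct import calcsize, pack

    FILE_FMT = '<dQQQQ'
    assert calcsize(FILE_FMT) == 40  # 8 (error_rate) + 4 * 8

    header = pack(FILE_FMT,
                  0.001,  # error_rate     (d)
                  10,     # num_slices     (Q)
                  2876,   # bits_per_slice (Q)
                  1000,   # capacity       (Q)
                  0)      # count          (Q)
    # ...the bitarray payload follows, padded up to a byte boundary.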

    def __getstate__(self):
        d = self.__dict__.copy()
        del d['make_hashes']
@@ -192,6 +228,7 @@ def __setstate__(self, d):
class ScalableBloomFilter(object):
    SMALL_SET_GROWTH = 2  # slower, but takes up less memory
    LARGE_SET_GROWTH = 4  # faster, but takes up more memory faster
    FILE_FMT = '<idQd'

    def __init__(self, initial_capacity=100, error_rate=0.001,
                 mode=SMALL_SET_GROWTH):
@@ -221,11 +258,14 @@ def __init__(self, initial_capacity=100, error_rate=0.001,
"""
if not error_rate or error_rate < 0:
raise ValueError("Error_Rate must be a decimal less than 0.")
self._setup(mode, 0.9, initial_capacity, error_rate)
self.filters = []

def _setup(self, mode, ratio, initial_capacity, error_rate):
self.scale = mode
self.ratio = 0.9
self.ratio = ratio
self.initial_capacity = initial_capacity
self.error_rate = error_rate
self.filters = []

    def __contains__(self, key):
        """Tests a key's membership in this bloom filter.
@@ -277,6 +317,47 @@ def capacity(self):
    def count(self):
        return len(self)

    def tofile(self, f):
        """Serialize this ScalableBloomFilter into the file-object
        `f'."""
        f.write(pack(self.FILE_FMT, self.scale, self.ratio,
                     self.initial_capacity, self.error_rate))

        # Write #-of-filters
        f.write(pack('<l', len(self.filters)))

        if len(self.filters) > 0:
            # Then each filter directly, with a header describing
            # their lengths.
            headerpos = f.tell()
            headerfmt = '<' + 'Q' * len(self.filters)
            f.write('.' * calcsize(headerfmt))
            filter_sizes = []
            for filter in self.filters:
                begin = f.tell()
                filter.tofile(f)
                filter_sizes.append(f.tell() - begin)

            # Back-patch the size header, then return to the end of the
            # stream so callers can keep writing after us.
            f.seek(headerpos)
            f.write(pack(headerfmt, *filter_sizes))
            f.seek(0, 2)

    @classmethod
    def fromfile(cls, f):
        """Deserialize the ScalableBloomFilter in file object `f'."""
        filter = cls()
        filter._setup(*unpack(cls.FILE_FMT, f.read(calcsize(cls.FILE_FMT))))
        nfilters, = unpack('<l', f.read(calcsize('<l')))
        if nfilters > 0:
            header_fmt = '<' + 'Q' * nfilters
            header_bytes = f.read(calcsize(header_fmt))
            filter_lengths = unpack(header_fmt, header_bytes)
            for fl in filter_lengths:
                filter.filters.append(BloomFilter.fromfile(f, fl))
        else:
            filter.filters = []

        return filter
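The size table has to be reserved before the filters are written because each sub-filter's byte length is only known after BloomFilter.tofile runs. A standalone sketch of that reserve-then-backpatch pattern (write_records and its payloads are hypothetical stand-ins, not pybloom API):

    from struct import pack, calcsize

    def write_records(f, payloads):
        f.write(pack('<l', len(payloads)))    # record count
        table_fmt = '<' + 'Q' * len(payloads)
        table_pos = f.tell()
        f.write('.' * calcsize(table_fmt))    # placeholder size table
        sizes = []
        for payload in payloads:
            start = f.tell()
            f.write(payload)
            sizes.append(f.tell() - start)
        end = f.tell()
        f.seek(table_pos)
        f.write(pack(table_fmt, *sizes))      # back-patch the real sizes
        f.seek(end)                           # leave position at the end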

    def __len__(self):
        """Returns the total number of elements stored in this SBF"""
        return sum([f.count for f in self.filters])
32 changes: 31 additions & 1 deletion pybloom/tests.py
@@ -1,5 +1,9 @@
import os
import doctest
import unittest
import random
import tempfile
from pybloom import BloomFilter, ScalableBloomFilter
from unittest import TestSuite

def additional_tests():
@@ -8,4 +12,30 @@ def additional_tests():
    suite = TestSuite([doctest.DocTestSuite('pybloom.pybloom')])
    if os.path.exists(readme_fn):
        suite.addTest(doctest.DocFileSuite(readme_fn, module_relative=False))
    return suite

class Serialization(unittest.TestCase):
    SIZE = 12345
    EXPECTED = set([random.randint(0, 10000100) for _ in xrange(SIZE)])

    def test_serialization(self):
        for klass, args in [(BloomFilter, (self.SIZE,)),
                            (ScalableBloomFilter, ())]:
            filter = klass(*args)
            for item in self.EXPECTED:
                filter.add(item)

            # It seems bitarray is finicky about the object being an
            # actual file, so we can't just use StringIO. Grr.
            f = tempfile.TemporaryFile()
            filter.tofile(f)
            del filter

            f.seek(0)
            filter = klass.fromfile(f)

            for item in self.EXPECTED:
                self.assert_(item in filter)

if __name__ == '__main__':
    unittest.main()
