Permalink
Browse files

Add *much better* serialization with tofile/fromfile.

This utilizes the underlying bitarray's tofile/fromfile
as well.

Add MANIFEST.in so that ez_setup.py is included in the
sdist.
  • Loading branch information...
1 parent e9f6038 commit 764f18fc98a43c58c90d3dd41cb149b1deda28e1 @mariusae mariusae committed Oct 22, 2009
Showing with 118 additions and 6 deletions.
  1. +1 −0 MANIFEST.in
  2. +86 −5 pybloom/pybloom.py
  3. +31 −1 pybloom/tests.py
View
@@ -0,0 +1 @@
+include ez_setup.py
View
@@ -34,7 +34,7 @@
"""
import math
import hashlib
-from struct import unpack, pack
+from struct import unpack, pack, calcsize
try:
import bitarray
@@ -83,6 +83,8 @@ def _make_hashfuncs(key):
class BloomFilter(object):
+ FILE_FMT = '<dQQQQ'
+
def __init__(self, capacity, error_rate=0.001):
"""Implements a space-efficient probabilistic data structure
@@ -117,15 +119,18 @@ def __init__(self, capacity, error_rate=0.001):
bits_per_slice = int(math.ceil(
(2 * capacity * abs(math.log(error_rate))) /
(num_slices * (math.log(2) ** 2))))
+ self._setup(error_rate, num_slices, bits_per_slice, capacity, 0)
+ self.bitarray = bitarray.bitarray(self.num_bits, endian='little')
+ self.bitarray.setall(False)
+
def _setup(self, error_rate, num_slices, bits_per_slice, capacity, count):
    """Set every bookkeeping attribute of the filter except the bit array.

    error_rate, num_slices, bits_per_slice, capacity and count are
    stored as given; num_bits is derived as
    num_slices * bits_per_slice, and make_hashes is rebuilt from the
    slice geometry.  The caller is responsible for assigning
    ``self.bitarray``.
    """
    self.error_rate = error_rate
    self.num_slices = num_slices
    self.bits_per_slice = bits_per_slice
    self.capacity = capacity
    self.num_bits = num_slices * bits_per_slice
    # Store the count that was passed in.  Hard-coding 0 here made a
    # filter rebuilt from a serialized header forget how many keys it
    # already held.
    self.count = count
    self.make_hashes = make_hashfuncs(self.num_slices, self.bits_per_slice)
def __contains__(self, key):
@@ -180,6 +185,37 @@ def add(self, key, skip_check=False):
self.count += 1
return False
def tofile(self, f):
    """Serialize this bloom filter to the file object `f'.

    A fixed-size header packed with ``FILE_FMT`` (error rate, slice
    count, bits per slice, capacity, element count) is written
    first; the bit array then appends its own raw contents through
    its ``tofile`` method.  This is far more compact than pickling
    the object.
    """
    header_fields = (self.error_rate, self.num_slices,
                     self.bits_per_slice, self.capacity, self.count)
    header = pack(self.FILE_FMT, *header_fields)
    f.write(header)
    # The bits follow the header directly, with no separator.
    self.bitarray.tofile(f)
+
@classmethod
def fromfile(cls, f, n=-1):
    """Read a bloom filter from file-object `f' serialized with
    ``BloomFilter.tofile''. If `n' > 0 read only so many bytes.

    `n' counts the header as well as the bit data, so it must be at
    least as large as the packed header.

    Raises ValueError if `n' is positive but smaller than the header,
    or if the number of bits read does not match the size announced
    in the header.
    """
    headerlen = calcsize(cls.FILE_FMT)

    if 0 < n < headerlen:
        raise ValueError('n too small!')

    bloom = cls(1)  # Bogus instantiation, we will `_setup'.
    bloom._setup(*unpack(cls.FILE_FMT, f.read(headerlen)))
    bloom.bitarray = bitarray.bitarray(endian='little')
    if n > 0:
        bloom.bitarray.fromfile(f, n - headerlen)
    else:
        bloom.bitarray.fromfile(f)

    # Bits are stored in whole bytes, so the array read back is either
    # exactly num_bits long or num_bits rounded up to the next byte
    # boundary.  Rounding up (instead of always adding 8 - num_bits % 8)
    # avoids accepting a spurious extra byte when num_bits is already a
    # multiple of 8.
    nbits_read = len(bloom.bitarray)
    padded_bits = (bloom.num_bits + 7) // 8 * 8
    if nbits_read != bloom.num_bits and nbits_read != padded_bits:
        raise ValueError('Bit length mismatch!')

    return bloom
+
def __getstate__(self):
d = self.__dict__.copy()
del d['make_hashes']
@@ -192,6 +228,7 @@ def __setstate__(self, d):
class ScalableBloomFilter(object):
SMALL_SET_GROWTH = 2 # slower, but takes up less memory
LARGE_SET_GROWTH = 4 # faster, but takes up more memory faster
+ FILE_FMT = '<idQd'
def __init__(self, initial_capacity=100, error_rate=0.001,
mode=SMALL_SET_GROWTH):
@@ -221,11 +258,14 @@ def __init__(self, initial_capacity=100, error_rate=0.001,
"""
if not error_rate or error_rate < 0:
raise ValueError("Error_Rate must be a decimal less than 0.")
+ self._setup(mode, 0.9, initial_capacity, error_rate)
+ self.filters = []
+
+ def _setup(self, mode, ratio, initial_capacity, error_rate):
self.scale = mode
- self.ratio = 0.9
+ self.ratio = ratio
self.initial_capacity = initial_capacity
self.error_rate = error_rate
- self.filters = []
def __contains__(self, key):
"""Tests a key's membership in this bloom filter.
@@ -277,6 +317,47 @@ def capacity(self):
def count(self):
return len(self)
def tofile(self, f):
    """Serialize this ScalableBloomFilter into the file-object `f'.

    Layout, in order:
      * the settings packed with ``FILE_FMT`` (scale, ratio, initial
        capacity, error rate);
      * the number of filters as a '<l' integer;
      * when there is at least one filter, a table of one '<Q' byte
        length per filter, followed by each filter's own ``tofile``
        output.

    `f' must support ``tell`` and ``seek`` because the length table
    is back-patched once the filters have been written.  On return
    the file position is at the end of the serialized data.
    """
    f.write(pack(self.FILE_FMT, self.scale, self.ratio,
                 self.initial_capacity, self.error_rate))

    # Write #-of-filters
    nfilters = len(self.filters)
    f.write(pack('<l', nfilters))

    if nfilters > 0:
        # Reserve room for the length table; the real values are only
        # known after each filter has been written.
        headerpos = f.tell()
        headerfmt = '<' + 'Q' * nfilters
        f.write(b'.' * calcsize(headerfmt))

        filter_sizes = []
        for bloom in self.filters:
            begin = f.tell()
            bloom.tofile(f)
            filter_sizes.append(f.tell() - begin)

        # Back-patch the table, then return to the end so that anything
        # the caller writes next does not overwrite the filter data.
        endpos = f.tell()
        f.seek(headerpos)
        f.write(pack(headerfmt, *filter_sizes))
        f.seek(endpos)
+
@classmethod
def fromfile(cls, f):
    """Deserialize the ScalableBloomFilter in file object `f'.

    Reads the settings packed with ``FILE_FMT``, then the '<l' filter
    count.  When the count is positive, a table of '<Q' byte lengths
    follows and each filter is loaded with ``BloomFilter.fromfile``
    using its recorded length.
    """
    sbf = cls()
    settings_size = calcsize(cls.FILE_FMT)
    sbf._setup(*unpack(cls.FILE_FMT, f.read(settings_size)))

    (nfilters,) = unpack('<l', f.read(calcsize('<l')))
    if nfilters <= 0:
        sbf.filters = []
        return sbf

    table_fmt = '<' + 'Q' * nfilters
    table = f.read(calcsize(table_fmt))
    for length in unpack(table_fmt, table):
        sbf.filters.append(BloomFilter.fromfile(f, length))
    return sbf
+
def __len__(self):
"""Returns the total number of elements stored in this SBF"""
return sum([f.count for f in self.filters])
View
@@ -1,5 +1,9 @@
import os
import doctest
+import unittest
+import random
+import tempfile
+from pybloom import BloomFilter, ScalableBloomFilter
from unittest import TestSuite
def additional_tests():
@@ -8,4 +12,30 @@ def additional_tests():
suite = TestSuite([doctest.DocTestSuite('pybloom.pybloom')])
if os.path.exists(readme_fn):
suite.addTest(doctest.DocFileSuite(readme_fn, module_relative=False))
- return suite
+ return suite
+
class Serialization(unittest.TestCase):
    """Round-trip both filter classes through tofile()/fromfile()."""

    SIZE = 12345
    EXPECTED = set([random.randint(0, 10000100) for _ in range(SIZE)])

    def test_serialization(self):
        """Every key added before serialization is found after reload."""
        for klass, args in [(BloomFilter, (self.SIZE,)),
                            (ScalableBloomFilter, ())]:
            original = klass(*args)
            for item in self.EXPECTED:
                original.add(item)

            # It seems bitarray is finicky about the object being an
            # actual file, so we can't just use StringIO. Grr.
            f = tempfile.TemporaryFile()
            try:
                original.tofile(f)
                # Drop the original so membership below can only be
                # answered by the reloaded filter.
                del original

                f.seek(0)
                restored = klass.fromfile(f)
            finally:
                # Close the temporary file even if (de)serialization fails.
                f.close()

            for item in self.EXPECTED:
                self.assertTrue(item in restored)


if __name__ == '__main__':
    unittest.main()

0 comments on commit 764f18f

Please sign in to comment.