In [None]:
import numpy as np
import mmh3
from math import log, ceil 

In [65]:
class BloomFilter:
    """Bloom filter for NumPy rows, backed by a boolean NumPy bit array.

    Membership tests can yield false positives but never false negatives:
    once a row has been passed to ``add``, ``row in bloom_filter`` is True.
    The k hash functions are ``mmh3.hash`` called with seeds 0..k-1.
    """

    def __init__(self, n, false_pos_rate):
        """Size the bit array and hash count for ``n`` expected insertions.

        Args:
            n: Expected number of elements to insert; must be at least 1.
            false_pos_rate: Target false-positive probability, strictly
                between 0 and 1.

        Raises:
            ValueError: If ``n`` is below 1 or ``false_pos_rate`` is not
                strictly between 0 and 1.
        """
        # Reject inputs that would otherwise divide by zero (n == 0), produce
        # a zero-bit filter (rate == 1) or a negative array size (rate > 1).
        if n < 1:
            raise ValueError(f"n must be a positive element count, got {n!r}")
        if not 0 < false_pos_rate < 1:
            raise ValueError(
                f"false_pos_rate must be strictly between 0 and 1, got {false_pos_rate!r}"
            )
        # Optimal values from the Wikipedia Bloom filter equations:
        #   m = -n * ln(p) / (ln 2)^2      (number of bits)
        #   k = (m / n) * ln 2             (number of hash functions)
        self.bit_arr_size = ceil(-n * log(false_pos_rate) / (log(2)**2))
        self.num_hash_fns = ceil(self.bit_arr_size * log(2) / n)
        print(f"Creating filter of size {self.bit_arr_size} with {self.num_hash_fns} hash fns")
        self.bit_arr = np.zeros(self.bit_arr_size, dtype=bool)

    def _get_bit_arr_idx(self, hash_algo_seed: int, row: np.ndarray) -> int:
        """Map ``row`` to a bit position using the hash seeded with ``hash_algo_seed``.

        The row is hashed through its raw bytes, so two rows holding equal
        values with different dtypes do not map to the same positions.
        """
        row_hash = mmh3.hash(row.tobytes(), hash_algo_seed)
        # Python's % yields a value in [0, bit_arr_size) even when the hash
        # is negative, so the result is always a valid index.
        return row_hash % self.bit_arr_size

    def _bit_indices(self, row: np.ndarray):
        """Yield the bit position chosen by each hash function for ``row``."""
        for hash_algo_seed in range(self.num_hash_fns):
            yield self._get_bit_arr_idx(hash_algo_seed, row)

    def add(self, row: np.ndarray) -> None:
        """Insert ``row`` by setting the bit at each of its k hash positions."""
        for bit_idx in self._bit_indices(row):
            self.bit_arr[bit_idx] = True

    def __contains__(self, row: np.ndarray) -> bool:
        """Return False if ``row`` was definitely never added, else True.

        A True result may be a false positive. The generator stops at the
        first unset bit, so a miss does not compute the remaining hashes.
        """
        return all(self.bit_arr[bit_idx] for bit_idx in self._bit_indices(row))

In [66]:
def deduplicate(data: np.ndarray, false_pos_rate: float) -> np.ndarray:
    """Drop repeated rows from a 2-D array, keeping each row's first occurrence.

    A Bloom filter screens every row. A filter miss proves the row is new;
    a filter hit is confirmed by an exact scan of the preceding rows, so
    Bloom false positives never cause a unique row to be dropped.

    Args:
        data: 2-D array whose rows are the items to deduplicate.
        false_pos_rate: Target false-positive rate used to size the filter.

    Returns:
        A new array holding the unique rows in their original order. Empty
        input yields an empty copy with the same shape and dtype.
    """
    n = data.shape[0]
    if n == 0:
        # Nothing to deduplicate. Returning here also avoids sizing a filter
        # for zero elements and dividing by zero in the rate report below.
        return data.copy()

    bf = BloomFilter(n, false_pos_rate)
    keep = np.zeros(n, dtype=bool)
    num_false_positives = 0
    for idx, row in enumerate(data):
        if row in bf:
            # The filter can report false positives, so confirm with an
            # exact comparison against every earlier row.
            row_truly_already_present = np.any(np.all(data[:idx] == row, axis=1))
            if row_truly_already_present:
                continue
            num_false_positives += 1
        # No false negatives: a row that reaches this point is unique so far.
        keep[idx] = True
        bf.add(row)
    print(f"False positive rate {num_false_positives / n}")
    # Boolean-mask indexing returns a copy with the input's dtype and order.
    return data[keep]

In [67]:
# Seed the generator so Restart & Run All rebuilds exactly the same dataset.
SEED = 42
rng = np.random.default_rng(SEED)

total_num_rows = 100_000
num_dups = 100

# Rows of 7 integers drawn from [0, 99); the upper bound is exclusive.
all_data = rng.integers(0, 99, size=(total_num_rows, 7))
assert len(all_data) == len(np.unique(all_data, axis=0)), "Initial data has duplicates; change SEED"

# Overwrite num_dups distinct rows with a copy of the last row. Sampling from
# range(total_num_rows - 1) guarantees the last row itself is never chosen.
rows_to_duplicate = rng.choice(total_num_rows - 1, num_dups, replace=False)
all_data[rows_to_duplicate] = all_data[-1]
assert total_num_rows - len(np.unique(all_data, axis=0)) == num_dups

In [68]:
# Target false-positive rate used to size the Bloom filter.
false_pos_rate = 0.01
deduped = deduplicate(all_data, false_pos_rate)

# Exactly the injected duplicates must have been removed...
assert len(deduped) == total_num_rows - num_dups
# ...and the surviving rows must be precisely the distinct rows of the input.
# The length check alone would still pass if the wrong rows were dropped.
assert np.array_equal(
    np.unique(deduped, axis=0), np.unique(all_data, axis=0)
), "deduplicated rows differ from the set of distinct input rows"

Creating filter of size 958506 with 7 hash fns
False positive rate 0.00175
