In [13]:
import itertools
import re
import math
import numpy as np
import random
import scipy.sparse as sparse
import time
from tqdm import trange, tqdm
from bloom_filter import BloomFilter
import lshash.lshash as lsh
from IPython.display import display, clear_output

# Read in regexes

In [2]:
with open('random_strings.txt', 'r') as file:
    regexes = [line.rstrip('\n') for line in file]
print(len(regexes))

306000


In [3]:
NUM_REGEX_CUTOFF = 1000 # for developement speed
regexes = regexes[:NUM_REGEX_CUTOFF]
num_regexes = len(regexes)
print(num_regexes)

1000


In [4]:
compiled_regexes = [re.compile(re.escape(regex)) for regex in regexes]

In [5]:
def string_to_ngram_set(text, n_lengths=[3]):
    '''Takes in a string and parses it to 2,3, and 4-grams'''
    for n in n_lengths:
        for i in range(len(text) - n + 1):
            yield text[i:i + n] # adds 2, 3, and 4-grams

# Configure bloom filters

In [6]:
# Create a bloom filter for each ngram, see if I can match it on a doc
ARR_LENGTH = 1024

def configure_bloom(some_set, max_elements=ARR_LENGTH, error_rate=0.1):
    b = BloomFilter(max_elements=max_elements, error_rate=error_rate)
    for obj in some_set:
        b.add(obj)
    return np.asarray(b.backend.array_)

test_arr = configure_bloom(set())
ARR_BIT_LENGTH = test_arr.nbytes * 8
print(test_arr)
print(f"Array bit length: {ARR_BIT_LENGTH}")

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0]
Array bit length: 9856


## Create LSH bins

In [26]:
def to_bit_array(x):
    return np.unpackbits(x.view(np.uint8))

def to_sparse_array(x):
    return sparse.dok_matrix(x)

def set_rep(string):
    '''Creates a set representation of a bloom filter'''
    trigrams = string_to_ngram_set(string)
    bf = sparse.dok_matrix(
            np.atleast_2d(to_bit_array(configure_bloom(trigrams)))
        )
    return {key[1] for key in bf.keys()}
    
def sparsify(regexes):
    bf_arrays = (configure_bloom(string_to_ngram_set(regex)) for regex in regexes)
    return [sparse.dok_matrix(np.atleast_2d(to_bit_array(arr))) for arr in bf_arrays]

def sparsify_single(string):
    return sparse.dok_matrix(np.atleast_2d(to_bit_array(configure_bloom(string_to_ngram_set(string)))))

def construct_lsh(num_bins, data):
    bits_needed = math.ceil(math.log2(num_bins))
    bins = lsh.LSHash(bits_needed, data.shape[1])
    for i, d in tqdm(enumerate(data)):
        bins.index(d, extra_data=i) # save index
    
    # Create dictionary of pooled indices
    tables = bins.hash_tables[0]
    keys = list(tables.keys())
    bin_dict = {key: [] for key in keys}
    for key in keys:
        for tup in tables.get_list(key):
            bin_dict[key].append(tup[1])
    return bins, bin_dict

def regex_to_bit_array(regex):
    bf = configure_bloom(string_to_ngram_set(regex))
    return to_bit_array(bf)

def regexes_to_bit_arrays(regexes):
    bf_arrays = [configure_bloom(string_to_ngram_set(regex)) for regex in regexes]
    bit_arrays = np.vstack([to_bit_array(arr) for arr in bf_arrays])
    return bit_arrays

def convert_regexes(regexes):
    return np.vstack([regex_to_bit_array(r) for r in regexes])

def sparse_intersection(pair):
    return pair[0].multiply(pair[1]).count_nonzero()

In [35]:
sparse_arrs = sparsify(regexes)
print(len(sparse_arrs))

1000


In [32]:
sets = [set_rep(regex) for regex in regexes]
len(sets)

1000

### Create lagoons (pools of pools)

In [42]:
SIZE_TO_EVAL = 5 # when a pool contains this or fewer regexes, evaluate the regexes
MIN_INTERSECTION = 5 # min number of bits on which arrays should intersect in order to pool

class Lagoon:
    def __init__(self, pools, pools_in_row, num_rows_in_pool):
        self.pools = pools
        self.pool_members = pools_in_row
        self.pool_counts = num_rows_in_pool
        self.nrows = len(pools)

In [33]:
def matches(query, pool_array):
    # Every key in pool_array is present in query
    for key in pool_array:
        if key not in query:
            return False
    return True

def get_closest_candidates(arr, bins, min_intersection, num_results):
    results = bins.query(arr, num_results=num_results, distance_func="hamming")
    results = [result[1] for result in results] # extract indices # TODO: filter distance here as well?
    if len(results) > 1:
        return set(results[1:])
    else:
        return None

def construct_lagoon(arrs, sets):
    '''arrs: tuple of dok row matrices
       sets: list of '''
    # Construct LSH bins and extract binned indices
    nrows = len(arrs)
    num_bins = nrows // 5
    
    # Convert tuple of DOKs to numpy array
    # TODO: hold on to original np array so this doesn't have to be done?
    arrs_np = np.vstack([dok.toarray() for dok in arrs])
    
    # Get bins & set up lagoon attributes
    bins, bin_dict = construct_lsh(num_bins, arrs_np)
    pool_members = []; pools = []; pool_counts = []
    
    def intersection_size(set1, set2):
        return len(set1.intersection(set2))
    
    # Pool the highest bit intersection from each bin
    for key, index_list in bin_dict.items():
        clear_output()
        pairs = itertools.combinations(index_list, 2)
        if len(index_list) > 1:              
            print("Pair:", pairs)
            m1, m2 = max(pairs, key = lambda p: intersection_size(sets[p[1]], sets[p[2]]))
            overlap = intersection_size(sets[m1], sets[m2])
            display(f"Closest pair: {(m1, m2)}, intersection {overlap}")
        
            if overlap >= MIN_INTERSECTION:
                index_list.remove(m1)
                index_list.remove(m2)
                pool_members.append((m1, m2))
                pooled_set = sets[m1].intersection(sets[m2])
                pools.append(pooled_set) # Allocate memory ahead of time?
                pool_counts.append(2)
        
        # Add remaining singletons to pool
        for idx in index_list:
            pools.append(sets[idx])
            pool_members.append([idx])
            pool_counts.append(1) 

    return Lagoon(pools, pool_members, pool_counts)


def construct_next_lagoon(previous):
    lagoon = construct_lagoon(previous.pools)
    
    for i, members in enumerate(lagoon.pool_members):
        lagoon.pool_counts[i] = sum([previous.pool_counts[member] for member in members])
    return lagoon

def create_lagoon_list(bf, max_length=3):
    lagoons = [construct_lagoon(bf)]
    while len(lagoons) < max_length:
        next_lagoon = construct_next_lagoon(lagoons[-1])
        # No more progress can be made if # rows stays the same
        if len(next_lagoon.pools) == len(lagoons[-1].pools):
            break
        lagoons.append(next_lagoon)
    lagoons.reverse()
    return lagoons

In [44]:
def construct_next_lagoon_sp(previous):
    lagoon = construct_lagoon(previous.pools)
    
    for i, members in enumerate(lagoon.pool_members):
        lagoon.pool_counts[i] = sum([previous.pool_counts[member] for member in members])
        arr = lagoon.pools[i]  # Reset the members from previous pool's array to the XOR
        for row in members:
            # The XOR in this case is child - parent, since parent is a subset of the child array
            previous.pools[row] = previous.pools[row] - arr
    return lagoon

def create_lagoon_list_sp(bf, max_length=3):
    lagoons = [construct_lagoon(bf)]
    while len(lagoons) < max_length:
        next_lagoon = construct_next_lagoon(lagoons[-1])
        # No more progress can be made if # rows stays the same
        if len(next_lagoon.pools) == len(lagoons[-1].pools): # change to nrows
            break
        lagoons.append(next_lagoon)
    lagoons.reverse()
    return lagoons

In [36]:
construct_lagoon(sparse_arrs, sets)

Pair: <itertools.combinations object at 0x7fb6e11c2f50>


IndexError: tuple index out of range

In [46]:
start = time.time()
lagoons = create_lagoon_list(sparse_arrs, max_length=8)
end = time.time()
print(end - start)

15.988216876983643


In [47]:
start = time.time()
lagoons2 = create_lagoon_list_sp(sparse_arrs, max_length=8)
end = time.time()
print(end - start)

17.59111499786377


In [48]:
print(lagoons2[0].pools)

(<1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 12 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 12 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 12 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 12 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 7 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 21 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 11 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 52 stored elements in Dictionary Of Keys format>, <1x9856 sparse matrix of type '<class 'numpy.uint8'>'
	with 40 stored elements in Dictionary Of Keys format>, <1x9856 sp

## Querying the lagoon

In [78]:
def sparse_match(a, b):   
    for key in b.keys():
        if key not in a:
            return False
    return True

def query_lagoon(query, lagoon_list):
    bases = []
    row_checks = 0
    search_next = list(range(len(lagoon_list[0].pools)))
    for i, lagoon in enumerate(lagoon_list):
        row_checks += len(search_next)
        candidates = [j for j in search_next if sparse_match(query, lagoon.pools[j])]
        search_next = []
        for candidate in candidates:
            if lagoon.pool_counts[candidate] > SIZE_TO_EVAL: # Gets the rows to check on the next lagoon
                search_next += [row for row in lagoon.pool_members[candidate]]
            else:
                bases += get_base_regexes([candidate], lagoon_list[i:]) # Recurses down lagoon to get bases
    return bases, row_checks

def get_base_regexes(rows, lagoon_list):
    next_rows = rows
    for lagoon in lagoon_list:
        next_rows = [r for row in next_rows for r in lagoon.pool_members[row]]
    return next_rows

        
def regexes_matching_string(query_string, lagoons):
    '''Resolves which regexes match the query_string by first
       prefiltering with the lagoons'''
    # TODO: refactor sparsify into 2 funcs so I can get the DOK directly here
#     q_arr = sparsify([query_string])[0]
    q_arr = sparsify_single(query_string)
    bases, row_checks = query_lagoon(q_arr, lagoons)
    matched_regexes = []
    for idx in bases:
        if compiled_regexes[idx].search(query_string):
            matched_regexes.append(idx)
    return {'matches': matched_regexes, 'regexes_checked': len(bases), 'rows_checked': row_checks}

def regexes_matching_string_uncompiled(query_string, lagoons):
    '''Resolves which regexes match the query_string by first
       prefiltering with the lagoons'''
#     q_arr = sparsify([query_string])[0]
    q_arr = sparsify_single(query_string)
    bases, row_checks = query_lagoon(q_arr, lagoons)
    matched_regexes = []
    for idx in bases:
        if re.search(re.escape(regexes[idx]), query_string):
            matched_regexes.append(idx)
    return {'matches': matched_regexes, 'regexes_checked': len(bases), 'rows_checked': row_checks}


In [79]:
test_sizes = [10, 100, 1000]

def test_compiled(num_tests):
    samples = random.sample(regexes, num_tests)
    texts = samples

    start = time.time()
    re_results = []
    for j in tqdm(range(len(texts))):
        for i in range(len(regexes)):
            x = compiled_regexes[i].search(texts[j])
            if x:
                re_results.append(x)
    end = time.time()

    results = []
    start2 = time.time()
    for i in tqdm(range(len(texts))):
        results.append(regexes_matching_string(texts[i], lagoons))
    end2 = time.time()

    print(f"Brute force: {end - start}; Tree: {end2 - start2}; Factor: {(end-start) / (end2-start2)}")
    
    return {'results': results, 'brute_time': end - start, 'drag_time': end2-start2}

test_results = [test_compiled(num) for num in test_sizes]

100%|██████████| 10/10 [00:00<00:00, 2813.65it/s]
100%|██████████| 10/10 [00:00<00:00, 247.42it/s]
100%|██████████| 100/100 [00:00<00:00, 3399.72it/s]
 31%|███       | 31/100 [00:00<00:00, 301.17it/s]

Brute force: 0.0054051876068115234; Tree: 0.041686058044433594; Factor: 0.12966415776349202


100%|██████████| 100/100 [00:00<00:00, 315.38it/s]
 38%|███▊      | 378/1000 [00:00<00:00, 3772.39it/s]

Brute force: 0.030543804168701172; Tree: 0.3183140754699707; Factor: 0.09595492792332594


100%|██████████| 1000/1000 [00:00<00:00, 3827.28it/s]
100%|██████████| 1000/1000 [00:02<00:00, 361.84it/s]

Brute force: 0.2630631923675537; Tree: 2.7651498317718506; Factor: 0.09513523981410739





In [23]:
for num, result in zip(test_sizes, test_results):
    reg = result['results'][0]['regexes_checked']
    rows = result['results'][0]['rows_checked']
    print("Num tests:", num, ":", "Brute time:", result['brute_time'], "Drag time:", result['drag_time'], "Row checks:", rows, "Direct evaluations: ", reg)

Num tests: 10 : Brute time: 0.004829883575439453 Drag time: 0.029380321502685547 Row checks: 389 Direct evaluations:  1
Num tests: 100 : Brute time: 0.027478694915771484 Drag time: 0.2505209445953369 Row checks: 389 Direct evaluations:  2
Num tests: 1000 : Brute time: 0.23934316635131836 Drag time: 2.4370667934417725 Row checks: 420 Direct evaluations:  73


### Test on sparsified lagoon

In [80]:
test_sizes = [10, 100, 1000]

def test_compiled(num_tests):
    samples = random.sample(regexes, num_tests)
    texts = samples

    start = time.time()
    re_results = []
    for j in tqdm(range(len(texts))):
        for i in range(len(regexes)):
            x = compiled_regexes[i].search(texts[j])
            if x:
                re_results.append(x)
    end = time.time()

    results = []
    start2 = time.time()
    for i in tqdm(range(len(texts))):
        results.append(regexes_matching_string(texts[i], lagoons2))
    end2 = time.time()

    print(f"Brute force: {end - start}; Tree: {end2 - start2}; Factor: {(end-start) / (end2-start2)}")
    
    return {'results': results, 'brute_time': end - start, 'drag_time': end2-start2}

test_results = [test_compiled(num) for num in test_sizes]

100%|██████████| 10/10 [00:00<00:00, 2540.16it/s]
100%|██████████| 10/10 [00:00<00:00, 237.58it/s]
100%|██████████| 100/100 [00:00<00:00, 3623.21it/s]
 34%|███▍      | 34/100 [00:00<00:00, 331.26it/s]

Brute force: 0.00572514533996582; Tree: 0.04370594024658203; Factor: 0.1309923847345567


100%|██████████| 100/100 [00:00<00:00, 324.59it/s]
 37%|███▋      | 369/1000 [00:00<00:00, 3689.41it/s]

Brute force: 0.02910590171813965; Tree: 0.3092920780181885; Factor: 0.09410490532003869


100%|██████████| 1000/1000 [00:00<00:00, 3793.21it/s]
100%|██████████| 1000/1000 [00:02<00:00, 379.97it/s]

Brute force: 0.2648320198059082; Tree: 2.632758855819702; Factor: 0.10059106599166807





### Uncompiled

In [81]:

def test_uncompiled(num_tests):
    samples = random.sample(regexes, num_tests)
    texts = samples

    start = time.time()
    re_results = []
    for j in tqdm(range(len(texts))):
        for i in range(len(regexes)):
            x = re.search(re.escape(regexes[i]), texts[j])
            if x:
                re_results.append(x)
    end = time.time()

    results = []
    start2 = time.time()
    for i in tqdm(range(len(texts))):
        results.append(regexes_matching_string_uncompiled(texts[i], lagoons))
    end2 = time.time()

    print(f"Brute force: {end - start}; Tree: {end2 - start2}; Factor: {(end-start) / (end2-start2)}")
    
    return {'results': results, 'brute_time': end - start, 'drag_time': end2-start2}

test_results_unc = [test_uncompiled(num) for num in test_sizes]



100%|██████████| 10/10 [00:00<00:00, 11.12it/s]
100%|██████████| 10/10 [00:00<00:00, 239.97it/s]
  2%|▏         | 2/100 [00:00<00:09, 10.54it/s]

Brute force: 0.9062891006469727; Tree: 0.04251599311828613; Factor: 21.316427870461236


100%|██████████| 100/100 [00:09<00:00, 11.00it/s]
100%|██████████| 100/100 [00:00<00:00, 290.01it/s]
  0%|          | 2/1000 [00:00<01:25, 11.62it/s]

Brute force: 9.089063882827759; Tree: 0.3457050323486328; Factor: 26.291384366258573


 35%|███▍      | 349/1000 [00:35<01:06,  9.79it/s]


KeyboardInterrupt: 

In [82]:
for num, result in zip(test_sizes, test_results_unc):
    reg = result['results'][0]['regexes_checked']
    rows = result['results'][0]['rows_checked']
    print("Num tests:", num, ":", "Brute time:", result['brute_time'], "Drag time:", result['drag_time'], "Row checks:", rows, "Direct evaluations: ", reg)

Num tests: 10 : Brute time: 0.8677456378936768 Drag time: 0.041022300720214844 Row checks: 407 Direct evaluations:  42
Num tests: 100 : Brute time: 8.620949983596802 Drag time: 0.30213093757629395 Row checks: 417 Direct evaluations:  71
Num tests: 1000 : Brute time: 86.35677599906921 Drag time: 2.8943521976470947 Row checks: 419 Direct evaluations:  103


In [83]:
print(sum([res['results'][0]['rows_checked'] for res in test_results_unc]))
print(sum([res['results'][1]['rows_checked'] for res in test_results_unc]))
print(sum([res['results'][2]['rows_checked'] for res in test_results_unc]))
print(sum([res['results'][0]['regexes_checked'] for res in test_results_unc]))
print(sum([res['results'][1]['regexes_checked'] for res in test_results_unc]))
print(sum([res['results'][2]['regexes_checked'] for res in test_results_unc]))

1243
1186
1206
216
42
93


In [84]:
print(sum([test_results_unc[0]['results'][i]['regexes_checked'] for i in range(10)]))
print(sum([test_results_unc[1]['results'][i]['regexes_checked'] for i in range(100)]))
print(sum([test_results_unc[2]['results'][i]['regexes_checked'] for i in range(1000)]))
print(sum([test_results_unc[0]['results'][i]['rows_checked'] for i in range(10)]))
print(sum([test_results_unc[1]['results'][i]['rows_checked'] for i in range(100)]))
print(sum([test_results_unc[2]['results'][i]['rows_checked'] for i in range(1000)]))

240
3586
33927
3982
40264
402933


In [85]:
import cProfile, pstats, io
from pstats import SortKey
pr = cProfile.Profile()

In [86]:
test_sizes = [10, 100, 1000]

def test_compiled(num_tests):
    samples = random.sample(regexes, num_tests)
    texts = samples

    start = time.time()
    re_results = []
    for j in tqdm(range(len(texts))):
        for i in range(len(regexes)):
            x = compiled_regexes[i].search(texts[j])
            if x:
                re_results.append(x)
    end = time.time()

    results = []
    start2 = time.time()
    pr.enable()
    for i in tqdm(range(len(texts))):
        results.append(regexes_matching_string(texts[i], lagoons2))
    pr.disable()
    end2 = time.time()

    print(f"Brute force: {end - start}; Tree: {end2 - start2}; Factor: {(end-start) / (end2-start2)}")
    
    return {'results': results, 'brute_time': end - start, 'drag_time': end2-start2}

test_results = [test_compiled(num) for num in test_sizes]


100%|██████████| 10/10 [00:00<00:00, 3481.04it/s]
100%|██████████| 10/10 [00:00<00:00, 216.67it/s]
100%|██████████| 100/100 [00:00<00:00, 3237.94it/s]
 22%|██▏       | 22/100 [00:00<00:00, 214.30it/s]

Brute force: 0.004554033279418945; Tree: 0.04777193069458008; Factor: 0.09532864201227728


100%|██████████| 100/100 [00:00<00:00, 221.08it/s]
 38%|███▊      | 381/1000 [00:00<00:00, 3809.74it/s]

Brute force: 0.03206586837768555; Tree: 0.45335984230041504; Factor: 0.07072939723769661


100%|██████████| 1000/1000 [00:00<00:00, 3804.05it/s]
100%|██████████| 1000/1000 [00:04<00:00, 232.16it/s]

Brute force: 0.2642381191253662; Tree: 4.308779001235962; Factor: 0.06132552146433368





In [87]:
s = io.StringIO()
sortby = SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())

         8258142 function calls (8257032 primitive calls) in 4.782 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1110    0.019    0.000    4.743    0.004 <ipython-input-78-6df9809fb359>:29(regexes_matching_string)
     1110    0.011    0.000    4.105    0.004 <ipython-input-69-ac640612672e>:11(sparsify_single)
2220/1110    0.020    0.000    3.649    0.003 /opt/miniconda3/envs/stats/lib/python3.7/site-packages/scipy/sparse/dok.py:74(__init__)
     2271    0.104    0.000    3.372    0.001 {method 'update' of 'dict' objects}
   140123    0.156    0.000    3.268    0.000 /opt/miniconda3/envs/stats/lib/python3.7/site-packages/scipy/sparse/_index.py:32(__getitem__)
   140123    0.223    0.000    2.448    0.000 /opt/miniconda3/envs/stats/lib/python3.7/site-packages/scipy/sparse/_index.py:126(_validate_indices)
   140123    0.447    0.000    1.561    0.000 /opt/miniconda3/envs/stats/lib/python3.7/site-packages/scipy/sparse

## Unused code

In [None]:
def jaccard_distance(A, B):
    return 1 - len(A.intersection(B)) / len(A.union(B))

    
def distance_matrix(X):
    '''Returns an upper-triangular matrix of hamming distances based on bit arrays'''
    n = len(X)
    distances = np.zeros((n,n)) # initialize 
    coords = np.triu_indices(n, 1, n)
    for i,j in tqdm(zip(coords[0], coords[1])):
        distances[i,j] = hamming_array_distance(X[i], X[j])
    return distances

def jaccard_array_distance(x, y):
    and_array = np.bitwise_and(x, y)
    or_array = np.bitwise_or(x, y)
    return 1 - one_bits(and_array) / one_bits(or_array)

def hamming_array_distance(x, y):
    '''Assumes x and y are equal length'''
    return one_bits(np.bitwise_xor(x, y))


def one_bits(x):
    '''Returns number of one bits in an array'''
    return sum([bin(entry).count("1") for entry in x])

def bit_intersection(x, y):
    return one_bits(np.bitwise_and(x, y))