In [1]:
import numpy as np

In [164]:
# Width of every input bit vector; point bit indices are also drawn from range(n_input_bits).
n_input_bits = 255
# Width of every output bit vector; each point's output index is drawn from range(n_output_bits).
n_output_bits = 8
# Number of distinct input-bit indices assigned to each point.
n_point_bits = 32
# Number of points created by gen_points() (and of per-point cluster lists).
n_points = 1000
# process_input() stores a new cluster for a point only when at least this many of its bits are active.
n_cluster_create_bits = 6
# Threshold compared (strictly greater than) against per-point active-bit counts in the filtering cell.
n_cluster_activate_bits = 4

In [165]:
def gen_input_output(count=1000, n_in_active_bits=40, n_out_active_bits=4) -> (np.array, np.array):
    """Generate `count` random sparse input/output bit-vector pairs.

    Row i of the first array has exactly `n_in_active_bits` ones among
    `n_input_bits` columns; row i of the second array has exactly
    `n_out_active_bits` ones among `n_output_bits` columns.  Active positions
    are drawn without replacement from NumPy's global RNG.

    Returns:
        (inputs, outputs): int8 arrays of shape (count, n_input_bits) and
        (count, n_output_bits).
    """
    inputs = np.zeros((count, n_input_bits), dtype=np.int8)
    outputs = np.zeros((count, n_output_bits), dtype=np.int8)
    # Iterating a 2-D array yields row views, so writes below land in the arrays.
    for in_row, out_row in zip(inputs, outputs):
        # For every row the input positions are drawn before the output positions.
        active_in = np.random.choice(n_input_bits, n_in_active_bits, replace=False)
        active_out = np.random.choice(n_output_bits, n_out_active_bits, replace=False)
        in_row[active_in] = 1
        out_row[active_out] = 1
    return inputs, outputs

def gen_points() -> (np.array, np.array):
    """Generate the random wiring of every point.

    Each of the `n_points` points receives `n_point_bits` distinct input-bit
    indices (drawn without replacement from range(n_input_bits)) and a single
    output-bit index (drawn from range(n_output_bits)).  All draws come from
    NumPy's global RNG.

    Returns:
        (point_ins, point_outs): int32 arrays of shape
        (n_points, n_point_bits) and (n_points,).
    """
    point_ins = np.zeros(shape=(n_points, n_point_bits), dtype=np.int32)
    point_outs = np.zeros(n_points, dtype=np.int32)
    for i in range(n_points):
        point_ins[i] = np.random.choice(n_input_bits, n_point_bits, replace=False)
        # choice(..., 1) returns a length-1 array; take its element explicitly,
        # because assigning a size-1 array into a single slot is deprecated
        # in NumPy >= 1.25 and scheduled to become an error.
        point_outs[i] = np.random.choice(n_output_bits, 1)[0]
    return point_ins, point_outs
                

In [198]:
# Seed NumPy's global RNG so that re-running this cell regenerates exactly the
# same inputs and point wiring (both generators draw from np.random).
np.random.seed(0)

# 1000 input/output bit-vector pairs with 50 active input bits per row.
in_data, out_data = gen_input_output(count=1000, n_in_active_bits=50)

# point_in_bits[i]: input-bit indices read by point i; point_out_bits[i]: its output-bit index.
point_in_bits, point_out_bits = gen_points()

# One initially empty (0, n_point_bits) cluster matrix and one empty stats list per point.
point_clusters = [np.empty(shape=(0, n_point_bits), dtype=np.int8) for _ in range(n_points)]
cluster_stats = [[] for _ in range(n_points)]

In [75]:
# filter active points for input bit vector
# Row i of point_masks holds the bits of the first input vector at point i's indices.
point_masks = in_data[0][point_in_bits]
# Number of points whose active-bit count is strictly above the activation threshold.
int((point_masks.sum(axis=1) > n_cluster_activate_bits).sum())

618

In [188]:
# test visualisation
# One print call, one item per line (sep='\n'), in this order:
#   1. bits of the first input vector gathered at every point's indices
#   2. the first input vector itself
#   3. the input-bit indices of point 0
#   4. indices of the active bits in the first input vector
#   5. point 0's indices that are active in the first input vector
#   6. the gathered bits for point 0

print(
    in_data[0][point_in_bits],
    in_data[0],
    point_in_bits[0],
    np.where(in_data[0] > 0),
    set(point_in_bits[0]) & set(np.where(in_data[0] > 0)[0]),
    in_data[0][point_in_bits][0],
    sep='\n',
)

[[1 1 1 ..., 0 1 0]
 [0 0 0 ..., 0 0 0]
 [1 0 0 ..., 0 0 0]
 ..., 
 [0 0 0 ..., 1 0 0]
 [0 0 0 ..., 0 0 0]
 [1 0 1 ..., 1 0 0]]
[0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 1 0 1 0 0 1 0 1 0
 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0
 0 0 1 0 0 0 0 0 0 1 1 0 1 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 1 0 1 1
 0 0 0 1 0 0 0 0 1 0 0 0 1 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 1
 0 1 0 0 1 0 0 0 1 0 0 0 0 0 1 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1
 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1]
[128 236 151 118 114  83  53 146 248 122  50   9 107   6 194  44 111 252
  27 190 153 228 169 183 131 103  78  21  95 139 160 225]
(array([  9,  13,  16,  21,  27,  31,  49,  61,  63,  65,  67,  70,  72,
        76,  87,  93, 100, 113, 120, 121, 123, 128, 131, 141, 144, 146,
       147, 151, 156, 160, 161, 164, 166, 176, 180, 184, 186, 189, 193,
      

In [195]:
from numba import njit

def clear_clusters():
    """Reset the module-level cluster storage to its empty state.

    Rebinds `point_clusters` to `n_points` empty int8 arrays of shape
    (0, n_point_bits) and `cluster_stats` to `n_points` independent empty
    lists.
    """
    global point_clusters, cluster_stats
    fresh_clusters = []
    fresh_stats = []
    for _ in range(n_points):
        fresh_clusters.append(np.empty(shape=(0, n_point_bits), dtype=np.int8))
        fresh_stats.append([])  # per-point list of stats entries
    point_clusters, cluster_stats = fresh_clusters, fresh_stats

def cluster_count():
    """Return the total number of clusters stored across all points.

    Adds up the length of every entry of the module-level `point_clusters`.
    """
    total = 0
    for clusters in point_clusters:
        total += len(clusters)
    return total

def append_cluster(point_idx: int, mask: np.array):
    """Store `mask` as a new cluster of point `point_idx` unless it is already covered.

    A mask counts as covered when some existing cluster of that point shares
    every one of the mask's non-zero bits (bitwise-AND overlap count equal to
    the mask's own non-zero count).  Otherwise the mask is appended as a new
    row of `point_clusters[point_idx]`.
    """
    global point_clusters
    existing = point_clusters[point_idx]
    # Per-cluster number of positions where both the cluster and the mask are set.
    overlap = np.count_nonzero(existing & mask, axis=1)
    if overlap.size != 0 and np.max(overlap) >= np.count_nonzero(mask):
        return  # an existing cluster already contains all bits of this mask
    point_clusters[point_idx] = np.vstack((existing, mask))
    
def update_clusters(point_idx: int, mask: np.array, active_bits: int):
    """Placeholder for per-point cluster updates; currently does nothing.

    Accepts a point index, that point's bit mask and an active-bit count,
    ignores all three and returns None.
    """
    return None
    
def process_input(input_bits: np.array):
    """Present one input bit vector to every point.

    For each point the input is sampled at the point's bit indices
    (`point_in_bits`).  When at least `n_cluster_create_bits` of the sampled
    bits are active, the sampled mask is offered to `append_cluster`.
    `update_clusters` is then called for every point with the mask and its
    active-bit count.

    Args:
        input_bits: 1-D array indexed by the values stored in `point_in_bits`.
    """
    # Fancy indexing: row i holds the input bits seen by point i.
    point_masks = input_bits[point_in_bits]
    point_sums = np.sum(point_masks, axis=1)
    for point_idx in range(n_points):
        mask = point_masks[point_idx]
        active_bits = point_sums[point_idx]
        if active_bits >= n_cluster_create_bits:
            append_cluster(point_idx, mask)
        # update_clusters requires the active-bit count as its third argument;
        # omitting it raises TypeError.
        update_clusters(point_idx, mask, active_bits)
    


In [199]:
%%time

# feed data
# Start from empty per-point cluster storage, then present every row of
# in_data to process_input in order. The whole cell is timed by %%time.
clear_clusters()

for bit_vector in in_data:
    process_input(bit_vector)

Wall time: 1min 3s


In [200]:
# Display the total number of clusters currently stored across all points.
cluster_count()

626701