In [None]:
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from scipy.sparse import csr_matrix

from htm.bindings.algorithms import SpatialPooler
from htm.bindings.sdr import SDR, Metrics

%matplotlib inline
    
seed = 1337

### Load data

Следующая ячейка загружает датасет MNIST (займет порядка 10-20 сек).

In [None]:
def load_ds(name, num_test, shape=None):
    """
    Fetch a dataset from openML.org and split it into train/test parts.

    The last `num_test` samples become the test set; everything before them
    is the train set. No shuffling is done here.

    @param name - dataset ID on openML (e.g. 'mnist_784')
    @param num_test - number of trailing samples to take as the test set
    @param shape - optional new shape of a single data point (i.e. data['data'][0])
                   as a list or tuple, e.g. [28, 28] for MNIST. The argument is
                   not modified.
    @return (train_labels, train_images, test_labels, test_images)
    """
    data = fetch_openml(name, version=1)
    sz = data['target'].shape[0]

    X = data['data']
    if shape is not None:
        # Build the full target shape without touching the caller's `shape`
        # (the previous `shape.insert(0, sz)` mutated it in place, so a second
        # call with the same list object would have reshaped incorrectly).
        # `np.asarray` keeps ndarray input as is and converts table-like input.
        X = np.reshape(np.asarray(X), [sz, *shape])

    y = data['target'].astype(np.int32)
    # split to train/test data
    split_at = sz - num_test
    train_labels = y[:split_at]
    train_images = X[:split_at]
    test_labels = y[split_at:]
    test_images = X[split_at:]

    return train_labels, train_images, test_labels, test_images


def shuffle_data(x, y):
    """
    Shuffle samples and labels with one shared random permutation.

    Uses the global NumPy RNG, so seed it beforehand for reproducibility.
    Returns new arrays `(x_shuffled, y_shuffled)`; pairs stay aligned.
    """
    order = np.arange(len(y))
    np.random.shuffle(order)
    samples = np.array(x)
    labels = np.array(y)
    return samples[order], labels[order]


# Download MNIST, keep the last 10000 samples as the test set and view every
# sample as a 28x28 image.
train_labels, train_images, test_labels, test_images = load_ds('mnist_784', 10000, shape=[28,28])

# Seed the global RNG so both shuffles below are reproducible.
np.random.seed(seed)
train_images, train_labels = shuffle_data(train_images, train_labels)
test_images, test_labels = shuffle_data(test_images, test_labels)

# Dataset dimensions kept as notebook-level globals.
n_train_samples = train_images.shape[0]
n_test_samples = test_images.shape[0]
image_shape = train_images[0].shape
image_side = image_shape[0]
# Pixels per image; squaring the first side assumes square images.
image_size = image_side ** 2


# Cell output: shapes of the four arrays.
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

Пример формата данных датасета

In [None]:
# Show the first (shuffled) training image together with its label, its shape
# and the raw pixel values of its middle row.
plt.imshow(train_images[0])
print(f'Label: {train_labels[0]}')
print(f'Image shape: {image_shape}')
print(f'Image middle row: {train_images[0][image_side//2]}')

Перекодируем датасет в бинарные изображения и дальше будем работать с бинарными данными.

In [None]:
def plot_flatten_image(flatten_image, image_height=28):
    """Draw a flattened image after folding it back into `image_height` rows."""
    image_2d = flatten_image.reshape((image_height, -1))
    plt.imshow(image_2d)

def to_binary_flatten_images(images):
    """
    Flatten every image to a vector and binarize it against its own mean.

    A pixel becomes 1 when it is >= the mean pixel value of the image it
    belongs to, otherwise 0. Returns an int8 array of shape (n_samples, n_pixels).
    """
    flat = images.reshape((images.shape[0], -1))
    # one threshold per image, kept as a column so it broadcasts over pixels
    per_image_mean = flat.mean(axis=1, keepdims=True)
    return (flat >= per_image_mean).astype(np.int8)


# Replace both image sets with their flattened binary versions.
# NOTE: this rebinds `train_images`/`test_images`; the cell is meant to be run
# once after the loading cell.
train_images = to_binary_flatten_images(train_images)
test_images = to_binary_flatten_images(test_images)
plot_flatten_image(train_images[0])

## 02. Baseline: classifier on raw input

In [None]:
%%time

def test_bare_classification(x_tr, y_tr, x_tst, y_tst):
    """
    Baseline: fit a logistic-regression classifier directly on the inputs.

    Prints and returns the accuracy (fraction of correct predictions) on the
    test samples.
    """
    classifier = LogisticRegression(
        tol=.001, max_iter=100, multi_class='multinomial',
        penalty='l2', solver='lbfgs', n_jobs=3,
    )
    classifier.fit(x_tr, y_tr)

    is_correct = classifier.predict(x_tst) == y_tst
    accuracy = is_correct.mean()
    print('Score:', 100 * accuracy, '%')
    return accuracy

# Use the first n train and test samples for a quick baseline run.
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

# Author's recorded result, presumably accuracy in % and wall time: 87.3; 888ms
test_bare_classification(x_tr, y_tr, x_tst, y_tst)

## 03. Spatial Pooler: skeleton

In [None]:
class NoOpSpatialPooler:
    """
    Pass-through pooler that defines the interface used by the pipeline:
    an `output_size` attribute and `compute(dense_sdr, learn)` returning the
    indices of active output bits. It performs no learning.
    """

    def __init__(self, input_size):
        # output has the same dimensionality as the input
        self.input_size = input_size
        self.output_size = self.input_size

    def compute(self, dense_sdr, learn):
        """Return the indices of non-zero entries along the first axis; `learn` is ignored."""
        nonzero_per_axis = np.nonzero(dense_sdr)
        return nonzero_per_axis[0]
        

# Smoke test of the pooler interface on a single image.
np.random.seed(seed)
sp = NoOpSpatialPooler(train_images[0].size)
sparse_sdr = sp.compute(train_images[0], True)

# The sparse form must list fewer indices than there are output bits.
print(sparse_sdr.size, sp.output_size)
assert sparse_sdr.size < sp.output_size

## 04. Train/test SP performance aux pipeline

In [None]:
%%time

def pretrain_sp(sp, images, n_samples):
    """Feed the first `n_samples` images to `sp.compute` with learning enabled; outputs are discarded."""
    warmup_batch = images[:n_samples]
    for image in warmup_batch:
        sp.compute(image, True)
    
def encode_to_csr_with_sp(images, sp, learn):
    """
    Encode every image with `sp` and stack the results into a CSR matrix.

    Row i holds ones at the column indices returned by `sp.compute(images[i], learn)`.
    The matrix has shape (images.shape[0], sp.output_size).
    """
    column_indices = []
    row_offsets = [0]
    for image in images:
        column_indices.extend(sp.compute(image, learn))
        # CSR row pointer: where the next row starts in `column_indices`
        row_offsets.append(len(column_indices))

    values = np.ones(len(column_indices))
    matrix_shape = (images.shape[0], sp.output_size)
    return csr_matrix((values, column_indices, row_offsets), shape=matrix_shape)

def test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp):
    """
    Train/test pipeline: spatial pooler encoding followed by logistic regression.

    The pooler is first warmed up on (up to) 1000 training samples, then the
    whole training set is encoded with learning still on. Test samples are
    encoded with learning off. Prints and returns the test accuracy.
    """
    # a small pretrain SP before real work
    pretrain_sp(sp, x_tr, n_samples=1000)

    # encode images and continuously train SP
    train_encoded = encode_to_csr_with_sp(x_tr, sp, learn=True)

    # train the classifier on the encoded representation
    classifier = LogisticRegression(
        tol=.001, max_iter=100, multi_class='multinomial',
        penalty='l2', solver='lbfgs', n_jobs=3,
    )
    classifier.fit(train_encoded, y_tr)

    # encode test images (without SP learning) and then test score
    test_encoded = encode_to_csr_with_sp(x_tst, sp, False)
    is_correct = classifier.predict(test_encoded) == y_tst
    accuracy = is_correct.mean()
    print('Score:', 100 * accuracy, '% for n =', len(x_tr))
    return accuracy

# Sanity check of the pipeline with the pass-through pooler on n samples.
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]
my_sp = NoOpSpatialPooler(train_images[0].size)

# Author's recorded result, presumably accuracy in % and wall time: 87.3; 1.16s
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, my_sp)

In [None]:
%%time

# Same pipeline with the pass-through pooler on the whole dataset.
np.random.seed(seed)

# n is larger than either split; slicing then simply returns every available row.
n = 100000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]
sp = NoOpSpatialPooler(train_images[0].size)

test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

## Dense Spatial Pooler

In [None]:
class DenseSpatialPooler:
    """
    Spatial pooler backed by dense (output_size, input_size) matrices.

    Every output column owns a random subset of inputs (`receptive_fields`)
    and a permanence value per potential synapse. A synapse counts as
    connected when its permanence is >= `permanence_threshold`. For an input,
    the `n_active_bits` columns with the largest boosted overlap win, and
    winners with a boosted overlap below `min_activation_threshold` are dropped.

    Parameters
    ----------
    input_size, output_size : int
        Number of input bits / output columns.
    permanence_threshold : float
        Permanence at or above which a synapse is connected.
    sparsity_level : float
        Target fraction of active columns; `n_active_bits = int(output_size * sparsity_level)`.
    synapse_permanence_deltas : (float, float)
        `(increment, decrement)` applied to synapses of winning columns on
        active / inactive inputs respectively.
    min_activation_threshold : float
        Minimal boosted overlap for a winner to be reported.
    potential_synapses_p : float
        Probability that an (output, input) pair is a potential synapse.
    max_boost_factor : float
        Strength of the activity-based boosting exponent.
    boost_sliding_window : (int, int)
        `(activity_duty_cycle, overlap_duty_cycle)` moving-average periods.
    """

    def __init__(
        self, input_size, output_size,
        permanence_threshold, sparsity_level, synapse_permanence_deltas, min_activation_threshold=1, potential_synapses_p=.8,
        max_boost_factor=1.5, boost_sliding_window=(1000, 1000)
    ):
        self.input_size = input_size
        self.output_size = output_size
        self.joint_shape = (output_size, input_size)

        self.sparsity_level = sparsity_level
        self.n_active_bits = int(self.output_size * sparsity_level)

        self.permanence_threshold = permanence_threshold
        self.syn_perm_inc, self.syn_perm_dec = synapse_permanence_deltas
        self.min_activation_threshold = min_activation_threshold

        self.max_boost_factor = max_boost_factor
        self.activity_duty_cycle, self.overlap_duty_cycle = boost_sliding_window

        # init (uses the global NumPy RNG)
        self.receptive_fields = np.random.choice(2, size=self.joint_shape, p=[1-potential_synapses_p, potential_synapses_p])
        # permanence is zero outside of the receptive field
        self.connections_permanence = np.random.uniform(size=self.joint_shape) * self.receptive_fields
        # builtin `float`/`bool` replace the `np.float`/`np.bool` aliases,
        # which were removed from NumPy
        self.time_avg_activity = np.full(self.output_size, self.sparsity_level, dtype=float)
        self.time_avg_overlap = np.ones(self.output_size, dtype=float)
        # scratch buffer for the per-input permanence delta
        self.dp = np.empty(input_size, dtype=float)
        self.boost = self._compute_boost()

    def compute(self, dense_sdr, learn):
        """
        Return indices of the winning columns for a dense binary input.

        When `learn` is true the permanences of the winners and the activity
        boosting statistics are updated as well.
        """
        dense_sdr = dense_sdr.astype(bool)
        active_cells = self.connections_permanence[:, dense_sdr] >= self.permanence_threshold
        overlaps = np.count_nonzero(active_cells, -1) * self.boost

        # top-k by boosted overlap, then drop winners that are too weak
        activated_cols = np.argpartition(-overlaps, self.n_active_bits)[:self.n_active_bits]
        activated_cols = activated_cols[overlaps[activated_cols] >= self.min_activation_threshold]

        if learn:
            self._update_permanence(dense_sdr, activated_cols)
            self._update_activity_boost(activated_cols)
            # NOTE: `_update_overlap_boost` is not called from here.

        return activated_cols

    def _update_permanence(self, dense_sdr, activated_cols):
        """Reinforce winners' synapses on active inputs and weaken those on inactive ones."""
        dp = self.dp
        dp[dense_sdr] = self.syn_perm_inc
        dp[~dense_sdr] = -self.syn_perm_dec
        updated = self.connections_permanence[activated_cols] + dp * self.receptive_fields[activated_cols]
        # Indexing with an index array yields a copy, so the result has to be
        # assigned back explicitly; otherwise nothing is ever learned.
        self.connections_permanence[activated_cols] = np.clip(updated, 0, 1)

    def _update_activity_boost(self, activated_cols):
        """Update the moving average of column activity and recompute boost factors."""
        self.time_avg_activity *= (self.activity_duty_cycle - 1) / self.activity_duty_cycle
        self.time_avg_activity[activated_cols] += 1 / self.activity_duty_cycle
        self.boost = self._compute_boost()

    def _update_overlap_boost(self, x, rows, cols, overlaps):
        """
        Raise permanences of the 5% of columns with the lowest average overlap.

        `x`, `rows` and `cols` are unused; they are kept so the signature
        stays as it was.
        """
        self.time_avg_overlap += (overlaps - self.time_avg_overlap) / self.overlap_duty_cycle
        k = int(.05 * self.output_size)
        to_boost_indices = np.argpartition(self.time_avg_overlap, k)[:k]
        # only potential synapses (inside the receptive field) are raised
        boosted = (
            self.connections_permanence[to_boost_indices]
            + .1 * self.permanence_threshold * self.receptive_fields[to_boost_indices]
        )
        # assign back: the indexed expression above is a copy
        self.connections_permanence[to_boost_indices] = np.clip(boosted, 0, 1)

    def _compute_boost(self):
        """Boost factor per column: > 1 for columns less active than average, < 1 otherwise."""
        return np.exp(-self.max_boost_factor * (self.time_avg_activity - self.time_avg_activity.mean()))
        

# Smoke test: build a small 100-column dense pooler and run one learning step.
# Positional arguments are: permanence_threshold=.5, sparsity_level=.04,
# synapse_permanence_deltas=(.1, .02), min_activation_threshold=4.
np.random.seed(seed)
my_sp = DenseSpatialPooler(train_images[0].size, 10**2, .5, .04, (.1, .02), 4, potential_synapses_p=.8)
my_sp.compute(train_images[0], True)

In [None]:
%%time

# Timing run: pretrain a 30x30-column dense pooler on n samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = DenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=30**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=4,
    max_boost_factor=3
)
# NOTE(review): "84.0; 3.24 s" looks copied from a classification cell;
# this cell only pretrains and reports no score. Confirm or drop.
pretrain_sp(sp, x_tr, n)

In [None]:
%%time

# Timing run: pretrain a 50x50-column dense pooler on n samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = DenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=50**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=4,
    max_boost_factor=3
)
# NOTE(review): "84.0; 3.24 s" looks copied from a classification cell;
# this cell only pretrains and reports no score. Confirm or drop.
pretrain_sp(sp, x_tr, n)

In [None]:
%%time

# Classification run with a 30x30-column dense pooler on n train/test samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = DenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=30**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=4,
    max_boost_factor=3
)
# Author's recorded result, presumably accuracy in % and wall time: 84.0; 3.24 s
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

## Dense SP Optimized

In [None]:
# Scratch experiment: does `out=` work with an index-array selection?
# Per NumPy's indexing rules `a[[1, 2]]` (advanced indexing) is a copy, so the
# clipped values land in a temporary and `a` shown below should be unchanged.
a = np.arange(12).reshape((3, 4))
np.minimum(a[[1, 2]], 6, out=a[[1,2]])
a

In [None]:
class OptimizedDenseSpatialPooler:
    """
    Dense spatial pooler with cached buffers and a precomputed connectivity mask.

    State
    -----
    receptive_fields : bool (output_size, input_size)
        Potential synapses of every column.
    connections_permanence : float (output_size, input_size)
        Permanence per synapse, kept inside [0, 1].
    connections_activity : bool (input_size, output_size)
        Transposed cache of `connections_permanence >= permanence_threshold`;
        it is refreshed for every column whose permanences change.

    For an input, the `n_active_bits` columns with the largest boosted overlap
    win; winners whose boosted overlap is below `min_activation_threshold`
    are dropped. Learning updates the winners' permanences, the activity
    boosting and, every `overlap_duty_cycle_period` learning steps, raises
    permanences of columns whose average overlap is below
    `overlap_low_bound_pct` of the mean average overlap.
    """

    def __init__(
        self, input_size, output_size, sparsity_level=.04,
        permanence_threshold=.5, min_activation_threshold=1, synapse_permanence_deltas=(.1, .02), potential_synapses_p=.8,
        max_boost_factor=1.5, activity_duty_cycle=1000, overlap_low_bound_pct=.25, overlap_duty_cycle=1000
    ):
        self.input_size = input_size
        self.output_size = output_size
        self.joint_shape = (output_size, input_size)

        self.sparsity_level = sparsity_level
        self.n_active_bits = int(self.output_size * sparsity_level)

        self.permanence_threshold = permanence_threshold
        self.syn_perm_inc, self.syn_perm_dec = synapse_permanence_deltas
        self.min_activation_threshold = min_activation_threshold

        self.max_boost_factor = max_boost_factor
        self.activity_duty_cycle = activity_duty_cycle
        self.boost_ma_decay = (self.activity_duty_cycle - 1) / self.activity_duty_cycle

        self.overlap_low_bound_pct = overlap_low_bound_pct
        self.overlap_duty_cycle = overlap_duty_cycle
        # overlap boosting is checked at most once per 100 learning steps
        self.overlap_duty_cycle_period = max(100, int(np.sqrt(overlap_duty_cycle)))
        self.overlap_ma_decay = (self.overlap_duty_cycle - 1) / self.overlap_duty_cycle
        self.timestamp = 0

        # init (uses the global NumPy RNG); builtin `bool`/`float` replace the
        # `np.bool`/`np.float` aliases, which were removed from NumPy
        self.receptive_fields = np.random.choice(2, size=self.joint_shape, p=[1-potential_synapses_p, potential_synapses_p]).astype(bool)
        self.connections_permanence = np.random.uniform(size=self.joint_shape) * self.receptive_fields
        self.connections_activity = self.connections_permanence.T >= self.permanence_threshold

        self.time_avg_activity = np.full(self.output_size, self.sparsity_level, dtype=float)
        self.time_avg_activity_mean = self.time_avg_activity.mean()
        self.boost = self._compute_boost()

        self.time_avg_overlap = np.zeros(self.output_size, dtype=float)
        self.time_avg_overlap_mean = self.time_avg_overlap.mean()

        # cache: `dp` holds -dec everywhere between calls; `dps` is a work
        # buffer with one row per (potential) winner
        self.dp = np.full(input_size, -self.syn_perm_dec, dtype=float)
        self.dps = np.full((self.n_active_bits, input_size), -self.syn_perm_dec, dtype=float)

    def compute(self, dense_sdr, learn):
        """Return indices of the winning columns; update the model when `learn` is true."""
        sparse_sdr = np.flatnonzero(dense_sdr)
        active_cells = self.connections_activity[sparse_sdr]
        overlaps = np.count_nonzero(active_cells, 0) * self.boost

        activated_cols = np.argpartition(-overlaps, self.n_active_bits)[:self.n_active_bits]
        activated_cols = activated_cols[overlaps[activated_cols] >= self.min_activation_threshold]

        if learn:
            self.timestamp += 1
            self._update_permanence(sparse_sdr, activated_cols)
            self._update_activity_boost(activated_cols)
            self._update_overlap_boost(activated_cols, overlaps)

        return activated_cols

    def _update_permanence(self, sparse_sdr, activated_cols):
        """Apply +inc / -dec to the winners' potential synapses and refresh the connectivity cache."""
        n_winners = activated_cols.size
        if n_winners == 0:
            return

        # The threshold filter in `compute` may leave fewer winners than
        # `n_active_bits`, so only the matching part of the buffer is used.
        dps = self.dps[:n_winners]

        self.dp[sparse_sdr] = self.syn_perm_inc
        dps[:] = self.connections_permanence[activated_cols]
        dps += self.dp * self.receptive_fields[activated_cols]
        # clip both ends: decremented synapses must not go below 0 either
        np.clip(dps, 0., 1., out=dps)
        self.connections_permanence[activated_cols] = dps
        # restore the cache invariant (all entries are -dec)
        self.dp[sparse_sdr] = -self.syn_perm_dec

        self.connections_activity[:, activated_cols] = dps.T >= self.permanence_threshold

    def _update_activity_boost(self, activated_cols):
        """Update moving averages of column activity and recompute boost factors."""
        decay = self.boost_ma_decay

        self.time_avg_activity *= decay
        self.time_avg_activity[activated_cols] += 1. - decay
        # the mean is tracked incrementally instead of being recomputed
        self.time_avg_activity_mean = decay * self.time_avg_activity_mean + (1 - decay) * activated_cols.size / self.output_size
        self.boost = self._compute_boost()

    def _update_overlap_boost(self, activated_cols, overlaps):
        """
        Track average overlaps and periodically strengthen under-performing columns.

        `activated_cols` is unused. Prints '+' whenever at least one column is boosted.
        """
        decay = self.overlap_ma_decay
        self.time_avg_overlap = decay * self.time_avg_overlap + (1 - decay) * overlaps
        self.time_avg_overlap_mean = decay * self.time_avg_overlap_mean + (1 - decay) * overlaps.mean()

        if self.timestamp % self.overlap_duty_cycle_period == 0:
            boosting_mask = self.time_avg_overlap < self.overlap_low_bound_pct * self.time_avg_overlap_mean
            boosting_indices = np.flatnonzero(boosting_mask)

            # `.size`, not `.any()`: index 0 alone is a valid non-empty selection
            if boosting_indices.size:
                print('+')
                boosted = np.minimum(
                    self.connections_permanence[boosting_indices] + self.syn_perm_inc * self.receptive_fields[boosting_indices],
                    1
                )
                # explicit assignment: `out=` on an index-array selection
                # would write into a temporary copy
                self.connections_permanence[boosting_indices] = boosted
                self.connections_activity[:, boosting_indices] = boosted.T >= self.permanence_threshold

    def _compute_boost(self):
        """Boost factor per column: > 1 for columns less active than average, < 1 otherwise."""
        return np.exp(-self.max_boost_factor * (self.time_avg_activity - self.time_avg_activity_mean))
        

# Smoke test: 100-column optimized dense pooler with default settings, one learning step.
np.random.seed(seed)
my_sp = OptimizedDenseSpatialPooler(train_images[0].size, 10**2)
my_sp.compute(train_images[0], True)

In [None]:
%%time

# Classification run: 30x30-column optimized dense pooler, n train/test samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = OptimizedDenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=30**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=2,
    max_boost_factor=3,
    potential_synapses_p=.5
)
# NOTE(review): "84.0; 3.24 s" is identical to the figure in the
# DenseSpatialPooler cells; confirm it was measured for this configuration.
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

In [None]:
%%time

# Timing run: pretrain a 30x30-column optimized dense pooler on n samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = OptimizedDenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=30**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=2,
    max_boost_factor=3,
    potential_synapses_p=.5
)
# NOTE(review): "84.0; 3.24 s" looks copied from a classification cell;
# this cell only pretrains and reports no score. Confirm or drop.
pretrain_sp(sp, x_tr, n)

In [None]:
%%time

# Timing run: pretrain a 50x50-column optimized dense pooler on n samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = OptimizedDenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=50**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=2,
    max_boost_factor=3,
    potential_synapses_p=.5
)
# NOTE(review): "84.0; 3.24 s" looks copied from a classification cell;
# this cell only pretrains and reports no score. Confirm or drop.
pretrain_sp(sp, x_tr, n)

In [None]:
%%time

# Classification run: 30x30-column optimized dense pooler on the whole dataset.
np.random.seed(seed)
# n is larger than either split; slicing then simply returns every available row.
n = 100000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = OptimizedDenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=30**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=2,
    max_boost_factor=3,
    potential_synapses_p=.5
)
# NOTE(review): "84.0; 3.24 s" matches the small-sample cells;
# confirm it applies to this full-size run.
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

In [None]:
%%time

# Classification run: 50x50-column optimized dense pooler on the whole dataset.
np.random.seed(seed)
# n is larger than either split; slicing then simply returns every available row.
n = 100000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = OptimizedDenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=50**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=2,
    max_boost_factor=3,
    potential_synapses_p=.5
)
# NOTE(review): "84.0; 3.24 s" matches the small-sample cells;
# confirm it applies to this full-size run.
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

In [None]:
%%time

# Timing run: pretrain a 65x65-column optimized dense pooler with a sparse
# potential pool (p=.15) on n samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = OptimizedDenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=65**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=4,
    max_boost_factor=3,
    potential_synapses_p=.15
)
# NOTE(review): "84.0; 3.24 s" looks copied from a classification cell;
# this cell only pretrains and reports no score. Confirm or drop.
pretrain_sp(sp, x_tr, n)

In [None]:
%%time

# Classification run: 65x65-column optimized dense pooler with a sparse
# potential pool (p=.15), n train/test samples.
np.random.seed(seed)
n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

sp = OptimizedDenseSpatialPooler(
    input_size=train_images[0].size, 
    output_size=65**2,
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=4,
    max_boost_factor=3,
    potential_synapses_p=.15
)
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

## Sparse Spatial Pooler

In [None]:
from collections import defaultdict, Counter

class SparseSpatialPooler:
    """
    Spatial pooler that stores synapses as per-bit index lists.

    State
    -----
    backward_connections : list of (inputs, permanences), one pair per output column
        Potential synapses of the column and their permanences (kept in [0, 1]).
    active_forward_connections : list of int arrays, one per input bit
        Output columns currently connected to that input, i.e. whose
        permanence for this input is >= `permanence_threshold`. This index is
        what `compute` uses, so it is kept in sync during learning.

    Parameters
    ----------
    input_size, output_size : int
    permanence_threshold : float
    sparsity_level : float
        `n_active_bits = int(output_size * sparsity_level)` columns win per input.
    potenrial_synapses_p : float
        Probability of a potential synapse (historical, misspelled name; kept
        for existing callers).
    synapse_permanence_deltas : (float, float)
        `(increment, decrement)` for winners' synapses on active / inactive inputs.
    min_activation_threshold : int
        Minimal overlap for a winner to be reported; values <= 0 disable the filter.
    potential_synapses_p : float, optional
        Correctly spelled alias; when given it overrides `potenrial_synapses_p`.
    """

    def __init__(
        self, input_size, output_size, permanence_threshold=.5, sparsity_level=.04, potenrial_synapses_p=.8,
        synapse_permanence_deltas=(.1, .02), min_activation_threshold=1, potential_synapses_p=None,
    ):
        if potential_synapses_p is not None:
            potenrial_synapses_p = potential_synapses_p

        self.input_size = input_size
        self.output_size = output_size
        self.joint_shape = (output_size, input_size)

        self.sparsity_level = sparsity_level
        self.n_active_bits = int(self.output_size * sparsity_level)

        self.permanence_threshold = permanence_threshold
        self.syn_perm_inc, self.syn_perm_dec = synapse_permanence_deltas
        self.synapse_permanence_deltas = np.array(synapse_permanence_deltas)
        self.min_activation_threshold = min_activation_threshold

        # one (columns, permanences) pair per input bit; uses the global NumPy RNG
        forward_connections = [
            self._make_random_connections(self.output_size, p=potenrial_synapses_p)
            for input_bit in range(self.input_size)
        ]
        self.backward_connections = self._to_backward_connections(self.output_size, forward_connections)
        self.active_forward_connections = [
            connections[permanences >= self.permanence_threshold]
            for connections, permanences in forward_connections
        ]

        # reusable overlap counter (builtin `int` replaces the removed `np.int` alias)
        self.overlaps = np.zeros(self.output_size, dtype=int)

    @staticmethod
    def _make_random_connections(size, p):
        """Pick each of `size` targets with probability `p` and draw a uniform permanence for it."""
        connections_mask = np.random.binomial(1, p, size=size)
        connections = np.flatnonzero(connections_mask)
        permanences = np.random.uniform(size=connections.size)

        return connections, permanences

    @staticmethod
    def _to_backward_connections(size, forward_connections):
        """Regroup per-input (columns, permanences) lists into per-column (inputs, permanences) arrays."""
        presynaptic_connections = [[] for _ in range(size)]
        presynaptic_permanences = [[] for _ in range(size)]

        for input_bit, (connections, permanences) in enumerate(forward_connections):
            for output_bit, permanence in zip(connections, permanences):
                presynaptic_connections[output_bit].append(input_bit)
                presynaptic_permanences[output_bit].append(permanence)

        # explicit dtypes: an empty list would otherwise become a float array,
        # which cannot be used as an index
        return [
            (np.array(connections, dtype=int), np.array(permanences, dtype=float))
            for connections, permanences in zip(presynaptic_connections, presynaptic_permanences)
        ]

    def _get_active_columns(self, sparse_sdr):
        """Count connected active inputs per column and return the top `n_active_bits` columns."""
        self.overlaps[:] = 0
        for input_bit in sparse_sdr:
            self.overlaps[self.active_forward_connections[input_bit]] += 1

        activated_cols = np.argpartition(-self.overlaps, self.n_active_bits)[:self.n_active_bits]
        if self.min_activation_threshold > 0:
            min_activation_mask = self.overlaps[activated_cols] >= self.min_activation_threshold
            activated_cols = activated_cols[min_activation_mask]

        return activated_cols

    def compute(self, dense_sdr, learn):
        """Return indices of the winning columns; update permanences when `learn` is true."""
        sparse_sdr = np.flatnonzero(dense_sdr)
        dense_sdr = dense_sdr.astype(bool)

        activated_cols = self._get_active_columns(sparse_sdr)
        if learn:
            self._update_permanence(dense_sdr, sparse_sdr, activated_cols)

        return activated_cols

    def _update_permanence(self, dense_sdr, sparse_sdr, activated_cols):
        """
        Update the winners' permanences and keep `active_forward_connections` in sync.

        Synapses on active inputs gain `syn_perm_inc`, the others lose
        `syn_perm_dec`; the result is clipped to [0, 1]. Synapses that cross
        `permanence_threshold` are added to / removed from the forward index.
        `sparse_sdr` is unused (kept for the existing call signature).
        """
        threshold = self.permanence_threshold
        to_add, to_remove = defaultdict(list), defaultdict(list)

        for output_bit in activated_cols:
            connections, permanences = self.backward_connections[output_bit]
            if connections.size == 0:
                continue

            was_connected = permanences >= threshold
            input_is_active = dense_sdr[connections]

            # `permanences` is the array stored in `backward_connections`,
            # so these in-place operations persist
            permanences += np.where(input_is_active, self.syn_perm_inc, -self.syn_perm_dec)
            np.clip(permanences, 0, 1, out=permanences)

            # Deriving the changes from the actual before/after comparison
            # guarantees the forward index equals `permanence >= threshold`.
            now_connected = permanences >= threshold
            for input_bit in connections[now_connected & ~was_connected]:
                to_add[int(input_bit)].append(output_bit)
            for input_bit in connections[was_connected & ~now_connected]:
                to_remove[int(input_bit)].append(output_bit)

        for input_bit in set(to_add) | set(to_remove):
            cols = self.active_forward_connections[input_bit]
            if input_bit in to_remove:
                cols = np.setdiff1d(cols, np.asarray(to_remove[input_bit], dtype=int))
            if input_bit in to_add:
                cols = np.union1d(cols, np.asarray(to_add[input_bit], dtype=int))
            self.active_forward_connections[input_bit] = cols

# Smoke test: 100-column sparse pooler, one learning step on the first image.
# (1337 is the same value as the notebook-level `seed`.)
np.random.seed(1337)
# my_sp = SparseSpatialPooler(10, 8, potenrial_synapses_p=.4)
my_sp = SparseSpatialPooler(train_images[0].size, 10**2, potenrial_synapses_p=.4)
my_sp.compute(train_images[0], True)

In [None]:
# Scratch experiment: indexing a 1-D array `b` with a 2-D integer array `a`
# works as a lookup table; the printed `b[a]` has the shape of `a`.
a = np.random.choice(10, size=(6, 3))
b = np.arange(10) * -2

print(a)
print(b[a])

In [None]:
# Scratch experiment: locate the positions where two equally shaped arrays differ.
a = np.array([[0, 1, 1, 0], [1, 0, 1, 0]])
b = np.array([[1, 1, 0, 0], [0, 1, 0, 0]])
print(a)
print(b)
print(a != b)
# cell output: a (row_indices, col_indices) tuple of the differing entries
np.nonzero(a!=b)

In [None]:
from collections import Counter

class OptimizedSparseSpatialPooler:
    """
    Spatial pooler with a fixed number of potential synapses per column,
    stored in rectangular arrays.

    State
    -----
    connections : int (output_size, n_synapses)
        Sorted input indices of every column's potential synapses
        (`n_synapses = int(input_size * potential_synapses_p)`).
    connections_permanence : float (output_size, n_synapses)
        Permanence per synapse, kept inside [0, 1].
    active_connections : int (input_size, active_connections_max_len)
    active_connections_lengths : int (input_size,)
        For every input bit, the first `lengths[ib]` entries of its row list
        output columns connected to it. Rows have a fixed capacity; connected
        columns that do not fit are silently left out, so the list is a
        subset of the truly connected columns.

    Unlike the dense poolers there is no minimal activation threshold:
    `compute` always returns `n_active_bits` columns.
    """

    def __init__(
        self, input_size, output_size, sparsity_level=.04,
        permanence_threshold=.5, synapse_permanence_deltas=(.1, .02), potential_synapses_p=None,
        max_boost_factor=1.5, activity_duty_cycle=1000
    ):
        self.input_size = input_size
        self.output_size = output_size
        self.joint_shape = (output_size, input_size)

        self.sparsity_level = sparsity_level
        self.n_active_bits = int(self.output_size * sparsity_level)

        if potential_synapses_p is None:
            # linear interpolation: .25 for output_size <= 2000 down to .05 for >= 20000
            potential_synapses_p = .05 + (20000 - np.clip(output_size, 2000, 20000)) / 18000. * (.25 - .05)
        assert(.05 <= potential_synapses_p <= .35)
        self.potential_synapses_p = potential_synapses_p

        self.permanence_threshold = permanence_threshold
        self.syn_perm_inc, self.syn_perm_dec = synapse_permanence_deltas

        self.max_boost_factor = max_boost_factor
        self.activity_duty_cycle = activity_duty_cycle
        self.boost_ma_decay = (self.activity_duty_cycle - 1) / self.activity_duty_cycle

        # init (uses the global NumPy RNG)
        # (output_size, ~input_size)
        self.connections = self._random_choice_noreplace_repeated(input_size, int(input_size * potential_synapses_p), output_size)
        self.connections_permanence = np.random.uniform(size=self.connections.shape)

        active_connections, active_connections_lengths, active_connections_max_len = self._get_active_connections()
        self.active_connections = active_connections
        self.active_connections_lengths = active_connections_lengths
        self.active_connections_max_len = active_connections_max_len

        # builtin `float`/`int` replace the removed `np.float`/`np.int` aliases
        self.time_avg_activity = np.full(self.output_size, self.sparsity_level, dtype=float)
        self.time_avg_activity_mean = self.time_avg_activity.mean()
        self.boost = self._compute_boost()

        # cache: `dp` holds -dec everywhere between calls; `dps` is a work
        # buffer with one row per winner
        self._active_synapses = None
        self.dp = np.full(input_size, -self.syn_perm_dec, dtype=float)
        self.dps = np.full((self.n_active_bits, self.connections.shape[1]), -self.syn_perm_dec, dtype=float)

    @staticmethod
    def _random_choice_noreplace_repeated(max_val, n_samples, n_times):
        """Return `n_times` sorted rows, each with `n_samples` distinct values from range(max_val)."""
        result = np.empty((n_times, n_samples), dtype=int)
        for i in range(n_times):
            result[i] = np.random.choice(max_val, size=n_samples, replace=False)
            result[i].sort()
        return result

    def _get_active_connections(self):
        """
        Build the per-input lists of connected columns from the permanences.

        Returns `(lists, lengths, max_len)`. Row capacity `max_len` is 1.25x
        the expected number of connected columns per input if about half of
        the potential synapses start connected; overflow is dropped.
        """
        max_len = int(self.output_size * self.potential_synapses_p * 1.25 / 2)
        result = np.empty((self.input_size, max_len), dtype=int)
        lengths = np.zeros(self.input_size, dtype=int)

        mask = self.connections_permanence >= self.permanence_threshold
        for ob in range(self.output_size):
            active_connections = self.connections[ob, mask[ob]]
            for ib in active_connections:
                if lengths[ib] < max_len:
                    result[ib, lengths[ib]] = ob
                    lengths[ib] += 1
        return result, lengths, max_len

    def compute(self, dense_sdr, learn):
        """Return indices of the `n_active_bits` winning columns; update the model when `learn` is true."""
        sparse_sdr = np.flatnonzero(dense_sdr)
        active_columns = self._get_active_columns(sparse_sdr)
        if learn:
            self._update_permanence(sparse_sdr, active_columns)
            self._update_activity_boost(active_columns)

        return active_columns

    def _get_active_columns(self, sparse_sdr):
        """Pick the columns with the largest boosted overlap with the active inputs."""
        if sparse_sdr.size:
            active_synapses = np.concatenate([
                self.active_connections[ib][:self.active_connections_lengths[ib]]
                for ib in sparse_sdr
            ])
        else:
            # `np.concatenate` rejects an empty list; an all-zero input simply
            # has no active synapses
            active_synapses = np.empty(0, dtype=int)
        negative_overlaps = np.bincount(active_synapses, minlength=self.output_size) * -self.boost

        active_columns = np.argpartition(negative_overlaps, self.n_active_bits)[:self.n_active_bits]
        return active_columns

    def _update_permanence(self, sparse_sdr, active_columns):
        """
        Update the winners' permanences and the per-input lists of connected columns.

        Synapses on active inputs gain `syn_perm_inc`, the others lose
        `syn_perm_dec`; results are clipped to [0, 1]. For every synapse that
        crossed `permanence_threshold`, its column is removed from or appended
        to the list of the corresponding input bit (appends beyond the row
        capacity are dropped).
        """
        threshold = self.permanence_threshold
        origin_active_status = (self.connections_permanence[active_columns] >= threshold)

        self.dp[sparse_sdr] = self.syn_perm_inc
        dps = self.dps
        dps[:] = self.connections_permanence[active_columns]
        dps += self.dp[self.connections[active_columns]]
        np.clip(dps, 0., 1., out=dps)
        self.connections_permanence[active_columns] = dps
        # Restore the cache invariant (all entries are -dec). It must not be
        # touched again below, or stale increments leak into later calls.
        self.dp[sparse_sdr] = -self.syn_perm_dec

        rows, poss = np.nonzero(origin_active_status != (dps >= threshold))

        to_remove, to_add = defaultdict(list), defaultdict(list)
        for row, ob, pos in zip(rows, active_columns[rows], poss):
            ib = int(self.connections[ob, pos])
            if origin_active_status[row, pos]:
                to_remove[ib].append(ob)
            else:
                to_add[ib].append(ob)

        max_len = self.active_connections_max_len
        for ib in set(to_remove) | set(to_add):
            n = int(self.active_connections_lengths[ib])

            if ib in to_remove:
                current = self.active_connections[ib, :n]
                # A removed column may be absent from the list (it was dropped
                # on overflow), so the new length comes from what is left.
                kept = current[np.isin(current, to_remove[ib], invert=True)]
                n = kept.size
                self.active_connections[ib, :n] = kept

            if ib in to_add:
                # append as many new columns as still fit into the row
                new_cols = to_add[ib][:max(max_len - n, 0)]
                if new_cols:
                    self.active_connections[ib, n:n + len(new_cols)] = new_cols
                    n += len(new_cols)

            self.active_connections_lengths[ib] = n

    def _update_activity_boost(self, active_columns):
        """Update moving averages of column activity and recompute boost factors."""
        decay = self.boost_ma_decay

        self.time_avg_activity *= decay
        self.time_avg_activity[active_columns] += 1. - decay
        # the mean is tracked incrementally instead of being recomputed
        self.time_avg_activity_mean = decay * self.time_avg_activity_mean + (1 - decay) * active_columns.size / self.output_size
        self.boost = self._compute_boost()

    def _compute_boost(self):
        """Boost factor per column: > 1 for columns less active than average, < 1 otherwise."""
        return np.exp(-self.max_boost_factor * (self.time_avg_activity - self.time_avg_activity_mean))
        

# Smoke test: 100-column optimized sparse pooler with default settings, one learning step.
np.random.seed(seed)
my_sp = OptimizedSparseSpatialPooler(train_images[0].size, 10**2)
my_sp.compute(train_images[0], True)

In [None]:
# Scratch experiment: combining an index matrix `a`, a boolean mask `b` of the
# same shape and a boolean lookup vector `c`.
np.random.seed(seed)

a = np.random.choice(10, size=(6, 3))
# builtin `bool` instead of `np.bool`: that alias was removed in NumPy 1.24
b = np.random.choice(2, size=(6, 3)).astype(bool)
c = np.random.choice(2, size=10).astype(bool)

print(a)
print(b)
print(c)
print(a[b])      # masked selection, flattened to 1-D
print(a * b)     # masked entries kept, the rest zeroed
print(c[a] * b)  # lookup of `c` at `a`, AND-ed with the mask

In [None]:
from collections import Counter

class OptimizedSparseSpatialPooler:
    """Spatial pooler with a fixed random set of potential synapses per output column.

    Every output column owns a sorted sample (without replacement) of input bit
    indices and one permanence value per sampled bit. A synapse counts as
    connected when its permanence is >= ``permanence_threshold``. ``compute``
    activates the ``n_active_bits`` columns with the largest boosted overlap.

    Parameters:
        input_size: number of input bits.
        output_size: number of output columns.
        sparsity_level: fraction of columns activated per step.
        permanence_threshold: permanence at which a synapse becomes connected.
        synapse_permanence_deltas: ``(increment, decrement)`` applied during learning.
        potential_synapses_p: fraction of input bits sampled per column; when None
            it is derived from ``output_size``. Must lie in [.05, .35].
        max_boost_factor: strength of the boost applied to rarely active columns.
        activity_duty_cycle: window (in steps) of the activity moving average.

    Random state is drawn from NumPy's global RNG (``np.random``).
    """

    def __init__(
        self, input_size, output_size, sparsity_level=.04,
        permanence_threshold=.5, synapse_permanence_deltas=(.1, .02), potential_synapses_p=None,
        max_boost_factor=1.5, activity_duty_cycle=1000
    ):
        self.input_size = input_size
        self.output_size = output_size
        self.joint_shape = (output_size, input_size)

        self.sparsity_level = sparsity_level
        self.n_active_bits = int(self.output_size * sparsity_level)

        if potential_synapses_p is None:
            # Linear interpolation: .25 for output_size <= 2000 down to .05 for output_size >= 20000.
            potential_synapses_p = .05 + (20000 - np.clip(output_size, 2000, 20000)) / 18000. * (.25 - .05)
        assert .05 <= potential_synapses_p <= .35, "potential_synapses_p must be within [.05, .35]"
        self.potential_synapses_p = potential_synapses_p

        self.permanence_threshold = permanence_threshold
        self.syn_perm_inc, self.syn_perm_dec = synapse_permanence_deltas

        self.max_boost_factor = max_boost_factor
        self.activity_duty_cycle = activity_duty_cycle
        self.boost_ma_decay = (self.activity_duty_cycle - 1) / self.activity_duty_cycle

        # Potential synapses: shape (output_size, int(input_size * potential_synapses_p)).
        self.connections = self._random_choice_noreplace_repeated(
            input_size, int(input_size * potential_synapses_p), output_size
        )
        self.connections_permanence = np.random.uniform(size=self.connections.shape)
        # Boolean mask of connected synapses, kept in sync by _update_permanence.
        self.active_connections_status = self.connections_permanence >= self.permanence_threshold

        # Builtin float/int are used as dtypes: the np.float / np.int aliases were removed in NumPy 1.24.
        self.time_avg_activity = np.full(self.output_size, self.sparsity_level, dtype=float)
        self.time_avg_activity_mean = self.time_avg_activity.mean()
        self.boost = self._compute_boost()

        # Caches reused between learning steps.
        self._active_synapses = None
        # Per-input-bit permanence delta; invariant between calls: every entry is -syn_perm_dec.
        self.dp = np.full(input_size, -self.syn_perm_dec, dtype=float)
        # Scratch buffer for the permanences of the active columns.
        self.dps = np.full((self.n_active_bits, self.connections.shape[1]), -self.syn_perm_dec, dtype=float)

    @staticmethod
    def _random_choice_noreplace_repeated(max_val, n_samples, n_times):
        """Return an (n_times, n_samples) int array; each row is a sorted sample
        of ``n_samples`` distinct values from ``range(max_val)``."""
        result = np.empty((n_times, n_samples), dtype=int)
        for i in range(n_times):
            result[i] = np.random.choice(max_val, size=n_samples, replace=False)
            result[i].sort()
        return result

    def _get_active_connections(self):
        """Build a per-input-bit index of the columns connected to that bit.

        Returns ``(result, lengths, max_len)``: ``result[ib, :lengths[ib]]`` holds
        the column indices with a connected synapse on input bit ``ib``. Each row
        is capped at ``max_len`` entries; columns beyond the cap are dropped, and
        the remainder of each row is uninitialized memory.

        Not called by ``__init__``; kept for the index-based ``_get_active_columns``.
        """
        max_len = int(self.output_size * self.potential_synapses_p * 1.25 / 2)
        result = np.empty((self.input_size, max_len), dtype=int)
        lengths = np.zeros(self.input_size, dtype=int)

        mask = self.connections_permanence >= self.permanence_threshold
        for ob in range(self.output_size):
            active_connections = self.connections[ob, mask[ob]]
            for ib in active_connections:
                if lengths[ib] < max_len:
                    result[ib, lengths[ib]] = ob
                    lengths[ib] += 1
        return result, lengths, max_len

    def compute(self, dense_sdr, learn):
        """Return the indices of the active columns for ``dense_sdr``.

        ``dense_sdr`` may have any shape with ``input_size`` elements (e.g. a 2-D
        image); it is flattened before use. When ``learn`` is truthy, permanences
        and boost statistics are updated from this step.
        """
        # Flatten so that indexing with self.connections (flat bit indices) is valid for images too.
        dense_sdr = np.asarray(dense_sdr).ravel()
        sparse_sdr = np.flatnonzero(dense_sdr)
        active_columns = self._get_active_columns2(sparse_sdr, dense_sdr)
        if learn:
            self._update_permanence(sparse_sdr, active_columns)
            self._update_activity_boost(active_columns)

        return active_columns

    def _get_active_columns2(self, sparse_sdr, dense_sdr):
        """Pick the ``n_active_bits`` columns with the highest boosted overlap.

        ``dense_sdr`` must be 1-D here. Overlap of a column is the number of its
        connected synapses that sit on non-zero input bits. ``sparse_sdr`` is
        unused and kept for signature parity with ``_get_active_columns``.
        """
        connected_inputs = dense_sdr[self.connections] * self.active_connections_status
        overlaps = np.count_nonzero(connected_inputs, axis=-1)
        # Apply the boost (as _get_active_columns does); negated so the smallest values win.
        negative_overlaps = overlaps * -self.boost

        active_columns = np.argpartition(negative_overlaps, self.n_active_bits)[:self.n_active_bits]
        return active_columns

    def _get_active_columns(self, sparse_sdr, dense_sdr):
        """Index-based variant of ``_get_active_columns2``.

        Reads ``self.active_connections`` and ``self.active_connections_lengths``
        (the layout produced by ``_get_active_connections``). ``__init__`` does not
        create these attributes, so they must be assigned before calling this.
        """
        active_synapses = np.concatenate([
            self.active_connections[ib][:self.active_connections_lengths[ib]]
            for ib in sparse_sdr
        ])
        negative_overlaps = np.bincount(active_synapses, minlength=self.output_size) * -self.boost

        active_columns = np.argpartition(negative_overlaps, self.n_active_bits)[:self.n_active_bits]
        return active_columns

    def _update_permanence(self, sparse_sdr, active_columns):
        """Reinforce the active columns' synapses on active input bits, weaken the rest.

        Permanences are clipped to [0, 1] and the connected-synapse mask of the
        active columns is refreshed. ``self.dp`` is restored to all
        ``-syn_perm_dec`` before returning so no increment leaks into later steps.
        """
        # Temporarily mark the active input bits with the increment, gather, then restore the cache.
        self.dp[sparse_sdr] = self.syn_perm_inc
        deltas = self.dp[self.connections[active_columns]]  # fancy indexing -> independent copy
        self.dp[sparse_sdr] = -self.syn_perm_dec

        dps = self.dps
        dps[:] = self.connections_permanence[active_columns]
        dps += deltas
        np.clip(dps, 0., 1., out=dps)
        self.connections_permanence[active_columns] = dps

        self.active_connections_status[active_columns] = dps >= self.permanence_threshold

    def _update_activity_boost(self, active_columns):
        """Update the moving averages of column activity and recompute the boost."""
        decay = self.boost_ma_decay

        self.time_avg_activity *= decay
        self.time_avg_activity[active_columns] += 1. - decay
        self.time_avg_activity_mean = decay * self.time_avg_activity_mean + (1 - decay) * active_columns.size / self.output_size
        self.boost = self._compute_boost()

    def _compute_boost(self):
        """Return exp(-max_boost_factor * (activity - mean activity)) per column."""
        return np.exp(-self.max_boost_factor * (self.time_avg_activity - self.time_avg_activity_mean))
        

# Smoke test: seed the global RNG, build a small pooler with 100 output columns
# and run one learning step on the first training image.
np.random.seed(seed)
my_sp = OptimizedSparseSpatialPooler(input_size=train_images[0].size, output_size=10**2)
my_sp.compute(train_images[0], learn=True)

In [None]:
%%time

# Pretraining experiment on the first n samples of each split.
np.random.seed(seed)
n = 3000
x_tr = train_images[:n]
y_tr = train_labels[:n]
x_tst = test_images[:n]
y_tst = test_labels[:n]

sp = OptimizedSparseSpatialPooler(
    train_images[0].size,
    65**2,
    sparsity_level=.04,
    permanence_threshold=.5,
    synapse_permanence_deltas=(.1, .03),
    potential_synapses_p=.15,
    max_boost_factor=3,
)
# Note kept from an earlier run: 84.0; 3.24 s (presumably score and wall time -- TODO confirm)
pretrain_sp(sp, x_tr, n)

In [None]:
%%time

# Classification experiment on the first n samples of each split;
# potential_synapses_p=None lets the pooler derive the fraction from output_size.
np.random.seed(seed)
n = 1000
x_tr = train_images[:n]
y_tr = train_labels[:n]
x_tst = test_images[:n]
y_tst = test_labels[:n]

sp = OptimizedSparseSpatialPooler(
    train_images[0].size,
    65**2,
    sparsity_level=.04,
    permanence_threshold=.5,
    synapse_permanence_deltas=(.1, .03),
    potential_synapses_p=None,
    max_boost_factor=3,
)
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

___

## TESTING