# Implement Spatial Pooler

Данная тетрадь содержит задачу реализации Spatial Pooler'а.

Для начала посмотри эпизоды 0–8 видеогайда [HTM School](https://www.youtube.com/watch?v=XMB0ri4qgwc&list=PL3yXMgtrZmDqhsFQzwUC9V8MeeVOQ7eZ9).

## 01. Getting ready

Данная секция содержит:

- [опционально] установка `htm.core`
- импорт необходимых пакетов (убедись, что все они установлены)
- загрузка датасета

### HTM.Core

Если у тебя не установлен пакет `htm.core`, раскомментируй и запусти следующую ячейку. В случае проблем, обратись к официальной странице пакета на [гитхабе](https://github.com/htm-community/htm.core#installation) и проверь требуемые зависимости.

In [None]:
# !python -m pip install -i https://test.pypi.org/simple/ htm.core

In [None]:
# import random
import numpy as np
# import sys
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
# from time import sleep

from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from scipy.sparse import csr_matrix

from htm.bindings.algorithms import SpatialPooler
from htm.bindings.sdr import SDR, Metrics

%matplotlib inline
    
seed = 1337

### Load data

Следующая ячейка загружает датасет MNIST (займет порядка 10-20 сек).

In [None]:
def load_ds(name, num_test, shape=None):
    """
    Fetch a dataset from openML.org and split it into train/test parts.

    The last ``num_test`` samples become the test set and everything before
    them the train set; no shuffling is done here.

    @param name - ID on openML (eg. 'mnist_784')
    @param num_test - num. samples to take as test
    @param shape - new shape of a single data point (ie data['data'][0]) as a
        list/tuple, e.g. [28, 28] for MNIST. The caller's object is NOT modified.
    @return (train_labels, train_images, test_labels, test_images) as numpy arrays
    @raises ValueError if num_test is not in [0, number of samples]
    """
    data = fetch_openml(name, version=1)
    # Depending on the scikit-learn version `data` may hold pandas objects;
    # convert to plain numpy so that reshaping and slicing are positional.
    X = np.asarray(data['data'])
    y = np.asarray(data['target'].astype(np.int32))
    sz = y.shape[0]
    if not 0 <= num_test <= sz:
        raise ValueError(f'num_test must be in [0, {sz}], got {num_test}')

    if shape is not None:
        # Build the full array shape without mutating the caller's `shape`
        # (`list.insert` would modify it in place and returns None).
        X = np.reshape(X, (sz, *shape))

    # split to train/test data
    n_train = sz - num_test
    return y[:n_train], X[:n_train], y[n_train:], X[n_train:]


def shuffle_data(x, y):
    """Return numpy copies of ``x`` and ``y`` permuted by one shared random order.

    The permutation is drawn from the global NumPy RNG, so the result is
    reproducible after ``np.random.seed``.
    """
    perm = np.arange(len(y))
    np.random.shuffle(perm)
    return tuple(np.array(arr)[perm] for arr in (x, y))


# Load MNIST with every sample reshaped to a 28x28 image; the last 10000
# samples become the test set.
train_labels, train_images, test_labels, test_images = load_ds('mnist_784', 10000, shape=[28, 28])

# Shuffle both splits reproducibly (shuffle_data draws from the global RNG).
np.random.seed(seed)
train_images, train_labels = shuffle_data(train_images, train_labels)
test_images, test_labels = shuffle_data(test_images, test_labels)

n_train_samples = train_images.shape[0]
n_test_samples = test_images.shape[0]
image_shape = train_images[0].shape
image_side = image_shape[0]  # image height (number of rows)
# Total pixel count, computed from the real shape so it also holds for
# non-square images.
image_size = int(np.prod(image_shape))


train_images.shape, train_labels.shape, test_images.shape, test_labels.shape

Пример формата данных датасета

In [None]:
# Show the first training sample and print a few facts about it.
plt.imshow(train_images[0])
print(f'Label: {train_labels[0]}')
print(f'Image shape: {image_shape}')
print(f'Image middle row: {train_images[0][image_side // 2]}')

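Перекодируем датасет в бинарные изображения: пиксель кодируется единицей, если его яркость не меньше средней яркости этого изображения, и нулём в противном случае. Дальше будем работать только с бинарными данными.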

In [None]:
def plot_flatten_image(flatten_image, image_height=None):
    """Show a flattened image as a 2-D picture.

    @param flatten_image - 1-D array of pixels
    @param image_height - number of rows; when omitted the image is assumed to
        be square and the height is inferred from the pixel count
    """
    flatten_image = np.asarray(flatten_image)
    if image_height is None:
        image_height = int(round(np.sqrt(flatten_image.size)))
    plt.imshow(flatten_image.reshape((image_height, -1)))

def to_binary_flatten_images(images, n_samples=None):
    """Flatten each image to a vector and binarize it by its own mean.

    A pixel becomes 1 when it is >= the mean intensity of its image, else 0.

    @param images - array with one image per leading index
    @param n_samples - number of images; defaults to ``len(images)``
    @return int8 array of shape (n_samples, pixels_per_image)
    """
    images = np.asarray(images)
    if n_samples is None:
        n_samples = images.shape[0]
    # flatten every image to vector
    images = images.reshape((n_samples, -1))
    # binary encoding: each pixel is compared with its own image's mean value
    return (images >= images.mean(axis=1, keepdims=True)).astype(np.int8)


# Pass the sample counts / image height explicitly: the helpers require them.
train_images = to_binary_flatten_images(train_images, n_train_samples)
test_images = to_binary_flatten_images(test_images, n_test_samples)
plot_flatten_image(train_images[0], image_side)

## 02. Baseline: classifier on raw input

In [None]:
%%time

def test_bare_classification(x_tr,  y_tr, x_tst, y_tst):
    """Fit logistic regression on raw (flattened) inputs and report test accuracy.

    @return accuracy on the test split as a fraction in [0, 1]
    """
    # Make sure every sample is a 1-D feature vector (images may still be 2-D).
    # This replaces the call to the undefined `flatten_images` helper.
    x_tr = np.asarray(x_tr).reshape(len(x_tr), -1)
    x_tst = np.asarray(x_tst).reshape(len(x_tst), -1)

    # With solver='lbfgs' a multinomial model is fitted for multiclass targets
    # by default; the explicit `multi_class` argument is deprecated.
    clf = LogisticRegression(tol=.001, max_iter=100, penalty='l2', solver='lbfgs', n_jobs=3)
    clf.fit(x_tr, y_tr)

    score = (clf.predict(x_tst) == y_tst).mean()
    print('Score:', 100 * score, '%')
    return score

n = 1000
# Use only the first `n` samples of each split to keep the baseline fast.
x_tr, x_tst = train_images[:n], test_images[:n]
y_tr, y_tst = train_labels[:n], test_labels[:n]

test_bare_classification(x_tr, y_tr, x_tst, y_tst)

## 03. Spatial Pooler: skeleton

In [None]:
class AbstractSpatialPooler:
    """Skeleton of a spatial pooler.

    ``compute`` simply returns the flat indices of the active (non-zero) input
    bits; subclasses override it together with the ``_update_*`` learning hooks.

    Attributes:
        input_shape / output_shape - shapes as tuples
        input_size / output_size - total number of input / output bits (ints)
    """

    def __init__(self, input_size, output_size):
        # Both an int (number of bits) and a shape tuple are accepted: keep the
        # shape for indexing and the flat size for sizes of sparse matrices.
        self.input_shape = self._as_shape(input_size)
        self.output_shape = self._as_shape(output_size)
        self.input_size = int(np.prod(self.input_shape))
        self.output_size = int(np.prod(self.output_shape))

    @staticmethod
    def _as_shape(size):
        """Normalize an int or an iterable of ints to a shape tuple."""
        if np.isscalar(size):
            return (int(size),)
        return tuple(int(dim) for dim in size)

    def compute(self, dense_sdr, learn):
        """Return flat indices of the non-zero bits of ``dense_sdr``.

        ``learn`` is ignored by this skeleton.
        """
        return np.flatnonzero(dense_sdr)

    def _update_permanence(self, sdr, rows, cols):
        """Learning hook: adjust synapse permanences (no-op here)."""

    def _update_activity_boost(self, rows, cols):
        """Learning hook: update activity-based boosting (no-op here)."""

    def _update_overlap_boost(self, sdr, rows, cols, overlaps):
        """Learning hook: update overlap-based boosting (no-op here)."""
        

np.random.seed(seed)
sp = AbstractSpatialPooler(train_images[0].shape, train_images[0].shape)
sparse_sdr = sp.compute(train_images[0], True)

# `output_size` may be a shape tuple, so compare against the total number of
# output bits rather than against the attribute itself (int < tuple fails).
n_output_bits = int(np.prod(sp.output_size))
print(sparse_sdr.size, n_output_bits)
assert sparse_sdr.size < n_output_bits
sparse_sdr

## 04. Train/test SP performance

In [None]:
%%time

def pretrain_sp(sp, images, n_samples):
    """Warm up ``sp`` by calling ``compute(..., learn=True)`` on the first ``n_samples`` images.

    The encoded outputs are discarded; only the side effects of ``sp.compute``
    matter.
    """
    warmup = images[:n_samples]
    for sample in warmup:
        sp.compute(sample, True)
    
def encode_to_csr_with_sp(images, sp, learn):
    """Encode every image with ``sp`` and pack the results into a sparse matrix.

    @param images - sequence of dense binary inputs, one per row
    @param sp - pooler whose ``compute(x, learn)`` returns flat indices of the
        active output bits and whose ``output_size`` is the number of output
        bits (an int or a shape tuple)
    @param learn - passed through to ``sp.compute``
    @return csr_matrix of shape (len(images), n_output_bits) with 1.0 at active bits
    """
    n_columns = int(np.prod(sp.output_size))
    active_indices = []
    indptr = [0]
    for img in images:
        active_indices.extend(sp.compute(img, learn))
        # row i occupies active_indices[indptr[i]:indptr[i + 1]]
        indptr.append(len(active_indices))

    data = np.ones(len(active_indices))
    indices = np.asarray(active_indices, dtype=np.int64)
    return csr_matrix((data, indices, indptr), shape=(len(indptr) - 1, n_columns))

def test_classification_with_sp(x_tr,  y_tr, x_tst, y_tst, sp, n_pretrain=1000):
    """Measure logistic-regression accuracy on SP-encoded inputs.

    The SP is first pre-trained on up to ``n_pretrain`` training images, then
    keeps learning while the training set is encoded; test images are encoded
    with learning disabled.

    @return accuracy on the test split as a fraction in [0, 1]
    """
    # a small pretrain SP before real work
    pretrain_sp(sp, x_tr, n_samples=n_pretrain)

    # encode images and continuously train SP
    train_csr = encode_to_csr_with_sp(x_tr, sp, learn=True)

    # With solver='lbfgs' a multinomial model is fitted for multiclass targets
    # by default; the explicit `multi_class` argument is deprecated.
    clf = LogisticRegression(tol=.001, max_iter=100, penalty='l2', solver='lbfgs', n_jobs=3)
    clf.fit(train_csr, y_tr)

    # encode test images (without SP learning) and then test score
    test_csr = encode_to_csr_with_sp(x_tst, sp, learn=False)
    score = (clf.predict(test_csr) == y_tst).mean()
    print('Score:', 100 * score, '% for n =', len(x_tr))
    return score

n = 1000
# Use only the first `n` samples of each split.
x_tr, x_tst = train_images[:n], test_images[:n]
y_tr, y_tst = train_labels[:n], test_labels[:n]

# The skeleton pooler passes active input bits through unchanged, so its
# output shape equals the input shape.
my_sp = AbstractSpatialPooler(train_images[0].shape, train_images[0].shape)
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, my_sp)

## 05. Spatial Pooler: learning

In [None]:
class BasicSpatialPooler(AbstractSpatialPooler):
    """Spatial pooler with k-winners-take-all activation and Hebbian learning.

    Every output column gets a random receptive field (each input bit is
    included with probability 0.8) with uniform random permanences. A synapse
    counts as connected when its permanence is >= ``permanence_threshold``. A
    column's overlap is the number of connected synapses onto active input
    bits; the ``n_active_bits`` columns with the largest overlap win, provided
    their overlap reaches ``min_activation_threshold``.

    @param input_shape - shape tuple of the dense binary input
    @param output_shape - 2-D shape tuple of the column grid
    @param permanence_threshold - permanence at which a synapse is connected
    @param sparsity_level - fraction of columns to activate
    @param synapse_permanence_deltas - (increment, decrement) applied on learning
    @param min_activation_threshold - minimal overlap for a winning column
    """

    def __init__(
        self, input_shape, output_shape,
        permanence_threshold, sparsity_level, synapse_permanence_deltas, min_activation_threshold
    ):
        super().__init__(input_shape, output_shape)
        # Set shapes/sizes explicitly: the rest of this class relies on tuple
        # shapes and integer sizes regardless of what the base class stores.
        self.input_shape = tuple(input_shape)
        self.output_shape = tuple(output_shape)
        self.input_size = int(np.prod(self.input_shape))
        self.output_size = int(np.prod(self.output_shape))

        self.sparsity_level = sparsity_level
        self.n_active_bits = int(self.output_size * sparsity_level)

        self.permanence_threshold = permanence_threshold
        self.synapse_permanence_increment, self.synapse_permanence_decrement = synapse_permanence_deltas
        self.min_activation_threshold = min_activation_threshold

        # initialization: one synapse slot per (column, input bit) pair
        self.joint_shape = self.output_shape + self.input_shape
        # 1 = the input bit lies in the column's receptive field (p = 0.8)
        self.receptive_fields = np.random.choice(2, size=self.joint_shape, p=[.2, .8])
        self.connections_permanence = np.random.uniform(size=self.joint_shape) * self.receptive_fields

        # reusable buffer for the per-input-bit permanence deltas
        self.dp = np.empty(self.input_shape, dtype=float)

    def compute(self, dense_sdr, learn):
        """Return flat indices of the winning columns for input ``dense_sdr``.

        When ``learn`` is true, the winners' permanences are updated.
        """
        dense_sdr = np.asarray(dense_sdr).astype(bool)
        # connected synapses onto active input bits, per column
        # (`...` keeps this valid for input of any rank)
        active_cells = self.connections_permanence[..., dense_sdr] >= self.permanence_threshold
        overlaps = np.count_nonzero(active_cells, -1).ravel()
        # `kth` must be a valid index even when every column may be active
        kth = min(self.n_active_bits, overlaps.size - 1)
        activated_indices = np.argpartition(-overlaps, kth)[:self.n_active_bits]
        activated_indices = activated_indices[overlaps[activated_indices] >= self.min_activation_threshold]

        if learn:
            rows, cols = np.unravel_index(activated_indices, self.output_shape)
            self._update_permanence(dense_sdr, rows, cols)

        return activated_indices

    def _update_permanence(self, sdr, rows, cols):
        """Hebbian update of the synapses of the columns at (rows, cols).

        Synapses from active input bits are incremented, the others
        decremented; only synapses inside a column's receptive field change,
        and the result is clipped to [0, 1]. ``sdr`` must be a boolean mask.
        """
        dp = self.dp
        dp[sdr] = self.synapse_permanence_increment
        dp[~sdr] = -self.synapse_permanence_decrement
        # Fancy indexing returns a copy, so the result must be written back
        # explicitly; otherwise learning has no effect.
        perm = self.connections_permanence[rows, cols]
        self.connections_permanence[rows, cols] = np.clip(perm + dp * self.receptive_fields[rows, cols], 0, 1)

        
np.random.seed(seed)
sp = BasicSpatialPooler(
    train_images[0].shape,
    (10, 10),
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .02),
    min_activation_threshold=4,
)
sparse_sdr = sp.compute(train_images[0], True)

# The pooler is expected to activate exactly `n_active_bits` columns.
print(sparse_sdr.size, sp.output_size, sp.n_active_bits)
assert sparse_sdr.size == sp.n_active_bits
sparse_sdr

In [None]:
%%time

n = 1000
x_tr, y_tr = train_images[:n], train_labels[:n]
x_tst, y_tst = test_images[:n], test_labels[:n]

# Seed like the other cells so the SP's random initialization (and therefore
# the reported score) is reproducible.
np.random.seed(seed)
sp = BasicSpatialPooler(
    train_images[0].shape, (30, 30),
    permanence_threshold=.5,
    sparsity_level=.04,
    synapse_permanence_deltas=(.1, .03),
    min_activation_threshold=4
)
test_classification_with_sp(x_tr, y_tr, x_tst, y_tst, sp)

In [None]:
def encode_img(img):
    """Binarize ``img``: 1 where a pixel is >= the image mean, else 0 (as int8)."""
    threshold = img.mean()
    return np.greater_equal(img, threshold).astype(np.int8)

# Binarize the first training image to use as a test input below.
sample = encode_img(train_images[0])

class MySpatialPooler:
    """Spatial pooler with k-WTA activation, Hebbian learning and boosting.

    Each output column has a random receptive field (each input bit included
    with probability 0.8) with uniform random permanences. Overlaps (connected
    synapses onto active input bits) are multiplied by a per-column boost
    factor ``exp(-max_boost_factor * (avg_activity - mean(avg_activity)))``,
    so columns that were active less often than average get a boost > 1.

    @param input_shape - shape tuple of the dense binary input
    @param output_shape - 2-D shape tuple of the column grid
    @param permanence_threshold - permanence at which a synapse is connected
    @param sparsity_level - fraction of columns to activate
    @param syn_perm_deltas - (increment, decrement) applied on learning
    @param min_activation_threshold - minimal (boosted) overlap of a winner
    @param max_boost_factor - steepness of the activity-based boost
    @param boost_sliding_window - (activity, overlap) moving-average windows
    @param use_overlap_boost - also apply the (experimental) overlap-based
        permanence boost on every learning step; off by default
    """

    def __init__(self, input_shape, output_shape, permanence_threshold, sparsity_level, syn_perm_deltas,
                 min_activation_threshold=1, max_boost_factor=1.5, boost_sliding_window=(1000, 1000),
                 use_overlap_boost=False):
        assert isinstance(input_shape, tuple) and isinstance(output_shape, tuple)
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.joint_shape = output_shape + input_shape
        self.output_size = int(np.prod(output_shape))

        self.sparsity_level = sparsity_level
        self.n_active_bits = int(self.output_size * sparsity_level)

        self.permanence_threshold = permanence_threshold
        self.syn_perm_inc, self.syn_perm_dec = syn_perm_deltas
        self.min_activation_threshold = min_activation_threshold

        self.max_boost_factor = max_boost_factor
        # window lengths (in learning steps) of the activity / overlap averages
        self.activity_duty_cycle, self.overlap_duty_cycle = boost_sliding_window
        self.use_overlap_boost = use_overlap_boost

        # init
        self.receptive_fields = np.random.choice(2, size=self.joint_shape, p=[.2, .8])
        self.connections_permanence = np.random.uniform(size=self.joint_shape) * self.receptive_fields
        # moving averages start at the target activity level / an overlap of 1
        self.time_avg_activity = np.full(self.output_shape, self.sparsity_level, dtype=float)
        self.time_avg_overlap = np.ones(self.output_shape, dtype=float)
        # reusable buffer for the per-input-bit permanence deltas
        self.dp = np.empty(input_shape, dtype=float)
        self.boost = self._compute_boost()

    def compute(self, x, learn):
        """Return flat indices of the winning columns for input ``x``.

        When ``learn`` is true, permanences and boosting statistics are updated.
        """
        x = np.asarray(x).astype(bool)
        # connected synapses onto active input bits, per column
        active_cells = self.connections_permanence[..., x] >= self.permanence_threshold
        overlaps = (np.count_nonzero(active_cells, -1) * self.boost).ravel()
        # `kth` must be a valid index even when every column may be active
        kth = min(self.n_active_bits, overlaps.size - 1)
        activated_indices = np.argpartition(-overlaps, kth)[:self.n_active_bits]
        activated_indices = activated_indices[overlaps[activated_indices] >= self.min_activation_threshold]

        if learn:
            rows, cols = np.unravel_index(activated_indices, self.output_shape)
            self._update_permanence(x, rows, cols)
            self._update_activity_boost(rows, cols)
            if self.use_overlap_boost:
                self._update_overlap_boost(x, rows, cols, overlaps)

        return activated_indices

    def _update_permanence(self, x, rows, cols):
        """Hebbian update of the receptive-field synapses of columns (rows, cols)."""
        dp = self.dp
        dp[x] = self.syn_perm_inc
        dp[~x] = -self.syn_perm_dec
        # Fancy indexing returns a copy, so the result must be written back.
        perm = self.connections_permanence[rows, cols]
        self.connections_permanence[rows, cols] = np.clip(perm + dp * self.receptive_fields[rows, cols], 0, 1)

    def _update_activity_boost(self, rows, cols):
        """Update the exponential moving average of column activity and the boost."""
        self.time_avg_activity *= (self.activity_duty_cycle - 1) / self.activity_duty_cycle
        self.time_avg_activity[rows, cols] += 1 / self.activity_duty_cycle
        self.boost = self._compute_boost()

    def _update_overlap_boost(self, x, rows, cols, overlaps):
        """Raise permanences of the 5% columns with the lowest average overlap.

        ``x``, ``rows`` and ``cols`` are unused; ``overlaps`` are the flat
        (boosted) overlaps of the current step.
        """
        self.time_avg_overlap += (overlaps.reshape(self.output_shape) - self.time_avg_overlap) / self.overlap_duty_cycle
        k = int(.05 * self.output_size)
        weakest = np.argpartition(self.time_avg_overlap.ravel(), k)[:k]
        weakest = np.unravel_index(weakest, self.output_shape)
        # Only synapses inside the receptive field are raised (as in
        # _update_permanence), and the result is written back because fancy
        # indexing returns a copy.
        bump = .1 * self.permanence_threshold * self.receptive_fields[weakest]
        self.connections_permanence[weakest] = np.clip(self.connections_permanence[weakest] + bump, 0, 1)

    def _compute_boost(self):
        """Boost factor per column: > 1 for columns less active than average."""
        return np.exp(-self.max_boost_factor * (self.time_avg_activity - self.time_avg_activity.mean()))
        

# Use the notebook-wide seed constant (same value, 1337) for reproducibility.
np.random.seed(seed)
my_sp = MySpatialPooler(
    train_images[0].shape, (10, 10),
    .5,           # permanence_threshold
    .04,          # sparsity_level
    (.1, .02),    # syn_perm_deltas: (increment, decrement)
    4,            # min_activation_threshold
)
my_sp.compute(sample, True)