# Sparse Hebbian Learning on MNIST

In [None]:
# %pip uninstall -y shl_scripts
# %pip install git+https://github.com/bicv/SparseHebbianLearning.git

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(precision=2, suppress=True)

In [None]:
%mkdir -p cache_data

The MNIST images are obtained as PNG files from https://github.com/rasbt/mnist-pngs


In [None]:
from shl_scripts.shl_experiments import SHL

# Only one (DEBUG_DOWNSCALE, verbose) pair is ever in effect. This cell used
# to assign (32, 10), then (1, 10), then (4, 10) back to back, so the first
# two pairs were overwritten immediately. Edit the value below to switch.
DEBUG_DOWNSCALE, verbose = 4, 10

# Figure names handed to learn_dico through its list_figures argument.
list_figures = [
    'plot_variance',
    'plot_variance_histogram',
    'time_plot_prob',
    'time_plot_kurt',
    'time_plot_var',
    'time_plot_MC',
]
homeo_method = 'HEH'

# Keyword options reused by every SHL(...) instance created in this notebook.
opts = dict(
    DEBUG_DOWNSCALE=DEBUG_DOWNSCALE,
    verbose=verbose,
    cache_dir='cache_data',
    datapath='database',
    name_database='mnist_train/1',
    height=28,
    width=28,
    patch_width=9,
    N_patches=2**16,
)
shl = SHL(homeo_method=homeo_method, **opts)
tag = 'MNIST'
data = shl.get_data(matname='mnist_train_1')
# The matname combines the dataset tag and the homeostasis method: 'MNIST_HEH'.
dico = shl.learn_dico(
    data=data,
    matname=tag + '_' + homeo_method,
    list_figures=list_figures,
)

In [None]:
dico.dictionary.shape, shl.patch_width**2

In [None]:
%mkdir -p figures

In [None]:
# Display the dictionary from the HEH run (order=False), then write the
# current pyplot figure to the figures/ folder.
png_path = 'figures/shl_HEH.png'
fig, ax = shl.show_dico(dico, order=False)
plt.savefig(png_path)

In [None]:
# Display the output of shl.show_Pcum for the HEH run, then write the
# current pyplot figure to the figures/ folder.
png_path = 'figures/shl_HEH_Pcum.png'
fig, ax = shl.show_Pcum(dico)
plt.savefig(png_path)

### control: learning with a simplified homeostasis

We build a simpler heuristic based on the probability of activation of the filters.

In [None]:
# Control run: same shared options and data, but with homeo_method='HAP'.
# Rebinds `shl` and `dico`, so later cells see the HAP experiment.
shl = SHL(**opts, homeo_method='HAP')
dico = shl.learn_dico(
    data=data,
    list_figures=list_figures,
    matname=f'{tag}_HAP',
)

In [None]:
# Display the dictionary from the HAP run (order=False), then write the
# current pyplot figure to the figures/ folder.
png_path = 'figures/shl_HAP.png'
fig, ax = shl.show_dico(dico, order=False)
plt.savefig(png_path)

In [None]:
# Display the output of shl.show_Pcum for the HAP run, then write the
# current pyplot figure to the figures/ folder.
png_path = 'figures/shl_HAP_Pcum.png'
fig, ax = shl.show_Pcum(dico)
plt.savefig(png_path)

### control: learning without homeostasis

During learning, to avoid divergence, the norm of the filters is kept at $1$.

In [None]:
# Control run without homeostasis. Note that homeo_method is the *string*
# 'None', not the None object -- presumably the value SHL uses to disable
# homeostasis; confirm against the SHL documentation before changing it.
shl = SHL(**opts, homeo_method='None')
dico = shl.learn_dico(
    data=data,
    list_figures=list_figures,
    matname=f'{tag}_nohomeo',
)

In [None]:
# Display the dictionary from the run without homeostasis (order=False),
# then write the current pyplot figure to the figures/ folder.
png_path = 'figures/shl_nohomeo.png'
fig, ax = shl.show_dico(dico, order=False)
plt.savefig(png_path)

In [None]:
# Display the output of shl.show_Pcum for the run without homeostasis,
# then write the current pyplot figure to the figures/ folder.
png_path = 'figures/shl_nohomeo_Pcum.png'
fig, ax = shl.show_Pcum(dico)
plt.savefig(png_path)

## Version used

In [None]:
%load_ext watermark
%watermark -i -h -m -v -p numpy,matplotlib,shl_scripts