# Sparse Hebbian Learning : reproducing SparseNet

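In this notebook, we test the convergence of SparseNet as a function of the parameters that tune the quantization: the rescaling value `C` and the quantization parameter `nb_quant`. These parameters only influence … (**TODO**: this sentence was left unfinished in the original — complete it.)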



In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
# Display numpy floats with 2 decimals in fixed-point notation.
np.set_printoptions(suppress=True, precision=2)

In [2]:
from shl_scripts.shl_experiments import SHL

# Figures requested from `SHL.learn_dico` when figures are wanted.
list_figures = ['show_dico', 'plot_variance', 'plot_variance_histogram',
                'time_plot_prob', 'time_plot_kurt', 'time_plot_var']

# Only this setting is effective. Alternatives tried during exploration:
# (DEBUG_DOWNSCALE, verbose) = (10, 0) or (10, 100).
DEBUG_DOWNSCALE, verbose = 1, 0
N_scan = 7  # NOTE(review): not used anywhere in this notebook — confirm before removing

# Load the training data once; every run below reuses it.
shl = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, verbose=verbose)
data = shl.get_data()

We use the `joblib` package to distribute these computations across several CPUs.

In [3]:
from joblib import Parallel, delayed


## Different rescaling values `C`

In [21]:
def run(C, list_figures, data):
    """Learn one dictionary with SHL's 'mp' learning algorithm for a rescaling value ``C``.

    Parameters
    ----------
    C : float
        Value forwarded to ``SHL(C=...)``; the scan below uses ``np.linspace(0, 10, 5)``.
    list_figures : list of str
        Figures forwarded to ``learn_dico``; the parallel scan passes ``[]``.
    data :
        Training data forwarded to ``learn_dico``.

    Returns
    -------
    The result of ``SHL.learn_dico``, run with ``matname='rescaling - C=<C>'``.
    """
    learner = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, learning_algorithm='mp',
                  C=C, verbose=verbose)
    return learner.learn_dico(data=data,
                              matname='rescaling - C={}'.format(C),
                              list_figures=list_figures)

In [22]:
# Preview the leading (C, list_figures) arguments passed to `run` for each C.
[(C, list()) for C in np.linspace(start=0, stop=10, num=5)]

[(0.0, []), (2.5, []), (5.0, []), (7.5, []), (10.0, [])]

In [23]:
# Sanity check that joblib dispatches work to 2 parallel workers.
Parallel(n_jobs=2)(delayed(np.sqrt)(k * k) for k in range(10))

[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

In [None]:
# Learn one dictionary per rescaling value C, using all available CPUs
# (no figures are requested inside the workers).
out = Parallel(n_jobs=-1, verbose=15)(
    delayed(run)(C, [], data)
    for C in np.linspace(start=0, stop=10, num=5)
)


## Different quantization parameters `nb_quant`

In [None]:
# Values of nb_quant scanned below: 8, 16, ..., 256.
np.power(2, np.arange(3, 9))

In [None]:
def run(nb_quant, figures=None, dataset=None):
    """Learn one dictionary with SHL's 'mp' learning algorithm for a given ``nb_quant``.

    The original signature ``run(nb_quant)`` did not match the call
    ``run(nb_quant, [], data)`` in the parallel scan below, which raised
    ``TypeError``. The extra arguments are now optional, so both call forms work.

    Note: this redefines the ``run`` used for the ``C`` scan. Re-running that
    scan after this cell would silently pass ``C`` as ``nb_quant``.

    Parameters
    ----------
    nb_quant : int
        Value forwarded to ``SHL(nb_quant=...)``; the scan below uses powers of two.
    figures : list of str, optional
        Figures forwarded to ``learn_dico(list_figures=...)``. Defaults to the
        notebook-level ``list_figures``; the parallel scan passes ``[]``.
    dataset : optional
        Training data forwarded to ``learn_dico(data=...)``. Defaults to the
        notebook-level ``data``.

    Returns
    -------
    The result of ``SHL.learn_dico``.
    """
    # Fall back lazily to the notebook globals. Defaults are not bound at
    # definition time, so later cells see the current values.
    if figures is None:
        figures = list_figures
    if dataset is None:
        dataset = data
    matname = 'rescaling - nb_quant={}'.format(nb_quant)
    shl = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE,
              learning_algorithm='mp', nb_quant=nb_quant, verbose=verbose)
    return shl.learn_dico(data=dataset, matname=matname, list_figures=figures)

In [None]:
# Learn one dictionary per nb_quant in {8, 16, ..., 256}, using all available
# CPUs (no figures are requested inside the workers).
out = Parallel(n_jobs=-1, verbose=15)(
    delayed(run)(nb_quant, [], data)
    for nb_quant in np.power(2, np.arange(3, 9))
)


## Version used

In [None]:
import version_information
%version_information numpy, shl_scripts