In [None]:
# Automatically reload edited modules (e.g. shl_scripts) before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
# Print arrays with 4 decimals, using fixed-point notation instead of scientific notation.
np.set_printoptions(precision=4, suppress=True)

# Sparse Hebbian Learning: testing the tools in the package

This notebook aims to showcase the different tools implemented in the `shl_scripts` package.


In [None]:
from shl_scripts.shl_experiments import SHL

# Run-wide knobs, forwarded as keyword arguments to the SHL(...) instances below.
DEBUG_DOWNSCALE = 1
verbose = 1

In [None]:
# Base name used by the cached files (see the cache_dir/{matname}* cleanup further down).
matname = 'test_tools'
shl = SHL(verbose=verbose, DEBUG_DOWNSCALE=DEBUG_DOWNSCALE)

In [None]:
# Show Python's built-in help for the SHL instance (its class docstring, methods and attributes).
help(shl)


## loading a database

Loading image patches, with or without a mask:

In [None]:
N_patches = 12
from shl_scripts.shl_tools import show_data

# Draw N_patches patches (n_image=1) once without and once with the mask enabled;
# each setting is shown in its own figure, labelled on the first axis.
for do_mask, label in [(False, 'Without mask'), (True, 'With mask')]:
    data = SHL(DEBUG_DOWNSCALE=1, verbose=verbose, N_patches=N_patches, n_image=1, do_mask=do_mask).get_data()
    fig, axs = show_data(data)
    axs[0].set_ylabel(label)
    plt.show()

In [None]:
N_patches = 12
from shl_scripts.shl_tools import show_data

# Re-seed numpy's global RNG from OS entropy. SHL gets its own seed=2018, presumably so
# both draws below select the same patches regardless of the global state (TODO confirm).
np.random.seed()
data_ = []
for do_bandpass, label in [(False, 'With no bandpass'), (True, 'With bandpass')]:
    data = SHL(seed=2018, verbose=verbose, N_patches=N_patches, n_image=1, do_bandpass=do_bandpass, over_patches=1).get_data()
    fig, axs = show_data(data)
    data_.append(data)
    axs[0].set_ylabel(label)
    plt.show()

# Element-wise difference between the unfiltered and the bandpass-filtered patches.
fig, axs = show_data(data_[0] - data_[1])
axs[0].set_ylabel('Difference')
plt.show()

Downscaling images to get a better signal-to-noise ratio:

In [None]:
# Sweep the downscaling factor patch_ds over 1, 2, 4, 8 and shrink the image
# height/width by the same factor, keeping N_patches and n_image fixed.
full_size = 480  # height and width passed to SHL when patch_ds == 1
for patch_ds in 2**np.arange(4):
    data = SHL(DEBUG_DOWNSCALE=1, verbose=0, height=full_size//patch_ds, width=full_size//patch_ds, N_patches=N_patches, n_image=1, patch_ds=patch_ds).get_data()
    fig, axs = show_data(data)
    axs[0].set_ylabel('patch_ds=' + str(patch_ds))
    plt.show()

In [None]:
# Sweep over_patches over 1, 2, 4, 8, 16 with N_patches and n_image fixed;
# each setting is shown in its own figure.
for over_patches in 2**np.arange(5):
    data = SHL(DEBUG_DOWNSCALE=1, verbose=0, N_patches=N_patches, n_image=1, over_patches=over_patches).get_data()
    fig, axs = show_data(data)
    # Label the figure with the parameter actually varied in this loop.
    axs[0].set_ylabel('over_patches=' + str(over_patches))
    plt.show()

Saving to a file:

In [None]:
# Passing matname stores the loaded patches in a file named after it (see "Saving to a file" above).
data = shl.get_data(matname=matname)


## initializing the dictionary


In [None]:
from shl_scripts.shl_tools import ovf_dictionary
# Build an initial dictionary whose atoms have patch_width**2 pixels. The first argument
# (here N_patches) presumably sets the number of atoms -- TODO confirm.
# It is stored under its own name so the patches held in `data` are not overwritten.
dictionary_init = ovf_dictionary(N_patches, n_pixels=shl.patch_width**2)
fig, axs = show_data(dictionary_init);


## caching tools : loading a database


In [None]:
# Fresh SHL instance for the caching demo; matname is the prefix of the cache files removed below.
matname = 'test_tools'
shl = SHL(verbose=verbose, DEBUG_DOWNSCALE=DEBUG_DOWNSCALE)

Let's first remove any existing data cache:

In [None]:
!rm cache_dir/{matname}*

If the data cache does not exist (as a file), it creates it:

In [None]:
# The cache was just removed, so this call builds the data and writes the cache file.
data = shl.get_data(matname=matname)

But if the data cache exists, it loads it:

In [None]:
# Identical call: the cache file now exists, so the data is loaded from it.
data = shl.get_data(matname=matname)


## caching tools : learning, then reloading the dictionary


If the dictionary does not exist (as a file), it learns it:

In [None]:
shl.n_iter = 129  # presumably the number of learning iterations used by learn_dico
list_figures = ['show_dico']
# No cached dictionary exists for matname yet, so it is learned (and cached) here.
dico = shl.learn_dico(data=data, list_figures=list_figures, matname=matname)

But if the dictionary exists, it loads it:

In [None]:
# Same call as above: the dictionary cached under matname now exists and is loaded instead of relearned.
dico = shl.learn_dico(data=data, matname=matname, list_figures=list_figures)


## caching tools : resuming a learning

If we pass a dictionary as an argument to the learning method, learning resumes from this dictionary and the data cache is overwritten.


In [None]:
# Passing dictionary= resumes learning from the current dico.dictionary (see the text above).
dico = shl.learn_dico(data=data, matname=matname, dictionary=dico.dictionary, list_figures=list_figures)


## caching tools : coding


In [None]:
%%time
# Sparse-code the data with the learned dictionary (no matname here); %%time reports the cost.
sparse_code = shl.code(data, dico)

In [None]:
# Same coding, now with matname -- presumably cached like get_data/learn_dico above (TODO confirm).
sparse_code = shl.code(data, dico, matname=matname)

## plotting tools

The simplest solution is to pass a list of figures to the learning method:

In [None]:
# Look up the 'kurt' entry recorded during learning (presumably the kurtosis history behind 'time_plot_kurt').
df_variable = dico.record['kurt']

In [None]:
# Number of dimensions of the recorded 'kurt' entry.
df_variable.ndim

In [None]:
# Figures requested from learn_dico through list_figures.
list_figures = [
    'show_dico',
    'time_plot_prob',
    'time_plot_kurt',
    'time_plot_var',
]
# No data= argument: the dictionary cached under matname is presumably reloaded (TODO confirm).
dico = shl.learn_dico(list_figures=list_figures, matname=matname)

But one can also generate each figure independently:

In [None]:
# Variance plot for the sparse code; the trailing ';' hides the returned value.
shl.plot_variance(sparse_code);

In [None]:
# Histogram version of the variance plot for the same sparse code.
shl.plot_variance_histogram(sparse_code);

In [None]:
# Overlay the 'error' time course of two learning runs on a single axis.
fig_error = ax_error = None  # no figure yet: the first time_plot call creates one
fig_error, ax_error = shl.time_plot(dico, variable='error', fig=fig_error, ax=ax_error, color='blue', label='one');
# Continue learning from the current dictionary without caching (matname=None).
dico = shl.learn_dico(data=data, dictionary=dico.dictionary, matname=None, list_figures=list_figures)
fig_error, ax_error = shl.time_plot(dico, variable='error', fig=fig_error, ax=ax_error, color='red', label='two');
ax_error.set_ylim(0, 0.9)
ax_error.legend(loc='best');

And combine them:

In [None]:
# Two stacked panels in one figure: the variance plot on top, the variance histogram below.
fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(13, 8))
fig, axs[0] = shl.plot_variance(sparse_code, ax=axs[0], fig=fig)
fig, axs[1] = shl.plot_variance_histogram(sparse_code, ax=axs[1], fig=fig)

## Version used

In [None]:
# Record the versions of numpy and shl_scripts used to run this notebook.
%load_ext version_information
%version_information numpy, shl_scripts