# BackProp Sampling Network

This is a continuation of the [Sampling Neural Network](sampling_neural_network.ipynb) notebook.  We noted there that our training process was extremely inefficient.  

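Recall that we want to Gibbs-sample each weight value $W_{\alpha}$ from its conditional distribution. Here $W_{w_{\alpha}=c_k}$ denotes the current weights with $w_{\alpha}$ replaced by the candidate value $c_k$:
$$
W_{\alpha} \sim \operatorname{Categorical} \Big(\operatorname{softmax}\big(\big[\log p(W_{\alpha}=c_k)+\sum_{n=1}^{N_{samples}}\log\big(f(x_n,W_{w_{\alpha}=c_k})_{y_n}\big),\ k \in 1..K\big]\big)\Big)
$$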

The problem was that we had to recompute the term
$$
g(x, w_{w_{\alpha}=c_k}) \equiv \log L(Y \mid X, w_{w_{\alpha}=c_k}) = \sum_{n=1}^{N_{samples}}\log\big(f(x_n,w_{w_{\alpha}=c_k})_{y_n}\big)
$$
This corresponds to a full forward pass over all data points. It had to be done for every $\alpha$, for every possible parameter value $c_k$, and for every sampling pass through the parameters.

Instead, we can make the same approximation as in backpropagation: we assume that $g(x, w)$ is locally linear in $w$.

So, to first order around the current weights $w$,
$$
g(x,w_{w_{\alpha}=c_k}) \simeq g(x, w) + \frac{\partial g(x, w)}{\partial w_{\alpha}} (c_k - w_\alpha)
$$

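Since the softmax is taken over $k$, any term that does not depend on $k$ does not matter. The first term, $g(x, w)$, can therefore be dropped. So can the $-\frac{\partial g(x, w)}{\partial w_{\alpha}} w_{\alpha}$ part of the second term, though we keep it below for clarity. All the partial derivatives $\frac{\partial g(x, w)}{\partial w_{\alpha}}$ come from a single forward and backward pass over the data, so we just compute: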

$$
W_{\alpha} \sim \operatorname{Categorical} \Big(\operatorname{softmax}\big(\big[\log p(W_{\alpha}=c_k)+\frac{\partial g(x, w)}{\partial w_{\alpha}} (c_k - w_\alpha),\ k \in 1..K\big]\big)\Big)
$$


## Experiment

We can run this on a multi-layer perceptron and compare the results with an MLP of the same architecture trained by ordinary gradient descent.

## Improvements
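- Rescale the minibatch gradient so that it estimates the gradient of the full-data sum $\sum_{n=1}^{N_{samples}}$ in the formula above. For a gradient of the minibatch *sum*, this means multiplying by size_of_full_batch/size_of_minibatch.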


In [0]:
which_dataset = 'mnist'     # Dataset to train on: 'mnist' or 'clusters'
n_hidden = 100              # Number of hidden units (single hidden layer in both networks)
mlp_eta = 0.1               # Learning rate of the gradient-descent MLP
gibbs_frac_update = 0.01    # Fraction of units updated with Gibbs (passed as GibbsSamplingMLP's frac_to_update)
possible_ws = (-1, 0, 1)    # Discrete set of values {c_k} that a Gibbs-MLP weight may take
n_epochs = 20               # Number of training epochs
n_test_points = 20          # Number of points during training at which the predictors are evaluated
minibatch_size = 20         # Number of samples per training minibatch
test_mode = False           # If True, shrink the run (fewer epochs / test points) for a quick check

In [0]:
from experimental.sampling_mlp import GibbsSamplingMLP
from plato.tools.cost import negative_log_likelihood
from plato.tools.networks import MultiLayerPerceptron
from plato.tools.online_prediction.online_predictors import GradientBasedPredictor
from plato.tools.optimizers import SimpleGradientDescent
from utils.benchmarks.predictor_comparison import compare_predictors
from utils.datasets.mnist import get_mnist_dataset
from utils.datasets.synthetic_clusters import get_synthetic_clusters_dataset
from utils.tools.mymath import sqrtspace
from general.should_be_builtins import bad_value
import numpy as np

# Select the dataset named by `which_dataset`; any other name is handed to
# bad_value together with an explanatory message.
if which_dataset == 'clusters':
    dataset = get_synthetic_clusters_dataset(n_dims=100)
elif which_dataset == 'mnist':
    dataset = get_mnist_dataset(flat = True)
else:
    dataset = bad_value(which_dataset, 'No dataset named "%s"' % which_dataset)

# Quick-check mode: train for only a tenth of an epoch and evaluate at 3 points.
if test_mode:
    n_epochs, n_test_points = 0.1, 3

# Input dimensionality given to both networks (first entry of the dataset's input shape).
n_inputs = dataset.input_shape[0]


def _small_gaussian_init(n_in, n_out):
    """Initial MLP weight matrix of shape (n_in, n_out): 0.1 * standard-normal draws."""
    return 0.1*np.random.randn(n_in, n_out)


# Baseline: an MLP with real-valued weights, trained by simple gradient descent
# on the negative log-likelihood.
mlp_predictor = GradientBasedPredictor(
    function = MultiLayerPerceptron(
        layer_sizes = [n_hidden, dataset.n_categories],
        input_size = n_inputs,
        output_activation = 'softmax',
        w_init = _small_gaussian_init,
        ),
    cost_function = negative_log_likelihood,
    optimizer = SimpleGradientDescent(eta = mlp_eta),
    ).compile()

# Sampled MLP of the same architecture, whose weights are drawn from the
# discrete set `possible_ws`; `gibbs_frac_update` is passed as frac_to_update.
# NOTE(review): confirm GibbsSamplingMLP uses the gradient-based (linearised)
# update described above rather than one forward pass per candidate value.
gibbs_predictor = GibbsSamplingMLP(
    layer_sizes = [n_hidden, dataset.n_categories],
    input_size = n_inputs,
    possible_ws = possible_ws,
    frac_to_update = gibbs_frac_update,
    output_activation = 'softmax',
    ).compile()

results = compare_predictors(
    dataset = dataset,
    online_predictors = {'MLP': mlp_predictor, 'Gibbs-MLP': gibbs_predictor},
    evaluation_function = 'percent_argmax_correct',
    minibatch_size = minibatch_size,
    # No accumulator for the MLP; 'avg' for the Gibbs-MLP (presumably averages
    # its predictions over successive samples -- confirm in compare_predictors).
    accumulators = {'MLP': None, 'Gibbs-MLP': 'avg'},
    test_epochs = sqrtspace(0, n_epochs, n_test_points),
    )

In [0]:
# Plot the comparison `results` and publish the figure under a name that
# records the dataset and whether this was a test-mode run.
from utils.benchmarks.plot_learning_curves import plot_learning_curves
from plotting.notebook_plots import link_and_show, set_link_and_show_mode

set_link_and_show_mode()
plot_learning_curves(results)
run_suffix = '-TEST' if test_mode else ''
link_and_show('Gibbs_vs_MLP-%s%s' % (which_dataset, run_suffix))