# Training an SNN using Neuroevolution!

Featuring Regularization to try reducing silent neurons!!!

In [16]:
import spyx
import spyx.nn as snn

# JAX imports
import os
import jax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"
from jax import numpy as jnp
import jmp
import numpy as np

from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# optimize the parameters using evosax
import evosax
from evosax.strategies import LM_MA_ES as LMMAES

# rendering tools
import matplotlib.pyplot as plt
%matplotlib notebook
import graphviz
import mediapy as media

In [17]:
policy = jmp.get_policy('half')

hk.mixed_precision.set_policy(hk.Linear, policy)
hk.mixed_precision.set_policy(hk.Conv2D, policy)
hk.mixed_precision.set_policy(hk.MaxPool, policy)
hk.mixed_precision.set_policy(spyx.activation.ActivityRegularization, policy)
hk.mixed_precision.set_policy(snn.LIF, policy)
hk.mixed_precision.set_policy(snn.LI, policy)

## Data Loading

In [18]:
mnist_dl = spyx.data.MNIST_loader(batch_size=64)

In [19]:
mnist_dl.train_step().obs.shape

(64, 64, 28, 28, 1)

## SNN

In [20]:
def mnist_snn(x):
    # seqs is [T, F].
    core = hk.DeepRNN([
        hk.Conv2D(16, 3, with_bias=False),
        snn.LIF((28,28,16), beta=0.8, activation=spyx.activation.Heaviside()),
        spyx.activation.ActivityRegularization(),
        hk.MaxPool(2,2, "SAME"),
        hk.Conv2D(32, 3, with_bias=False),
        snn.LIF((14,14,32), beta=0.8, activation=spyx.activation.Heaviside()),
        spyx.activation.ActivityRegularization(),
        hk.MaxPool(2, 2, "SAME"),
        hk.Flatten(),
        hk.Linear(10, with_bias=False),
        snn.LI((10,), beta=0.9)
    ])
    spikes, V = hk.dynamic_unroll(core, x.astype(jnp.float32), core.initial_state(x.shape[0]), time_major=False, unroll=16)
    return spikes, V

In [21]:
key = jax.random.PRNGKey(0)
SNN = hk.without_apply_rng(hk.transform_with_state(mnist_snn))
params, reg_init = SNN.init(rng=key, x=mnist_dl.train_step().obs)

## Evolution

In [22]:
def evolution(SNN, params, dl, epochs=25, test_every=1, key=0):

    rng = jax.random.PRNGKey(key)
    param_reshaper = evosax.ParameterReshaper(params)
    
    aug = spyx.data.shift_augment(max_shift=3, axes=(-3,-2))

    strategy = LMMAES(popsize=96,
                  num_dims=param_reshaper.total_params,
                )

    es_params = strategy.default_params
    es_params = es_params.replace()

    state = strategy.initialize(rng)

    @jax.jit
    def net_eval(individual, events, targets):
        readout, spike_counts = SNN.apply(individual, reg_init, events)
        traces, V = readout
        acc, _ = spyx.loss.integral_accuracy(traces, targets)
        return acc #xe_loss + 0.1 * reg_loss, xe_loss
    
    @jax.jit
    def net_test(individual, events, targets):
        readout, spike_counts = SNN.apply(individual, reg_init, events)
        traces, V = readout
        acc, pred = spyx.loss.integral_accuracy(traces, targets)
        return acc
    
    sim_fn = jax.vmap(net_eval, (0, None, None))
    test_fn = jax.vmap(net_test, (0, None, None))
    
    @jax.jit
    def step(rng, state, events, targets):
        rng, rng_ask = jax.random.split(rng, 2)
        # ASK
        pop, state = strategy.ask(rng_ask, state)
        population_params = param_reshaper.reshape(pop.astype(jnp.float32))
        # EVAL
        acc = sim_fn(population_params, events, targets) 
        # TELL
        state = strategy.tell(pop, -acc, state)        
        
        return rng, state, acc
    
    
    for gen in range(epochs):
        pbar = tqdm([*range(dl.train_len//dl.batch_size)])
        pbar.set_description("Epoch #{}".format(gen))
        dl.train_reset()
        for _ in pbar:
            events, targets = dl.train_step() # non-jittable...

            rng, state, acc = step(rng, state, events, targets)
            
            pbar.set_postfix(Fitness=jnp.max(acc))
        
        elite = param_reshaper.reshape(jnp.array([state.best_member]))
        if gen % test_every == test_every-1:
            dl.val_reset()
            accs = []
            
            pbar = tqdm([*range(dl.val_len//dl.batch_size)])
            pbar.set_description("Validate")
            for _ in pbar:
                events, targets = dl.val_step()
                
                acc = test_fn(elite, events, targets)
                
                accs.append(acc)
                
                pbar.set_postfix(Fitness=np.mean(accs))
        
    return jax.tree_util.tree_map(jnp.squeeze, elite)

In [23]:
elite_params = evolution(SNN, params, mnist_dl)

ParameterReshaper: 20432 parameters detected for optimization.


Epoch #0: 100%|████████████████| 656/656 [04:45<00:00,  2.30it/s, Loss=0.671875]
Validate: 100%|███| 281/281 [00:07<00:00, 38.53it/s, Accuracy=0.661, Loss=-.661]
Epoch #1: 100%|████████████████| 656/656 [04:40<00:00,  2.34it/s, Loss=0.703125]
Validate: 100%|███| 281/281 [00:06<00:00, 43.72it/s, Accuracy=0.723, Loss=-.723]
Epoch #2: 100%|████████████████| 656/656 [04:40<00:00,  2.34it/s, Loss=0.796875]
Validate: 100%|█████| 281/281 [00:06<00:00, 43.70it/s, Accuracy=0.71, Loss=-.71]
Epoch #3: 100%|██████████████████| 656/656 [04:40<00:00,  2.34it/s, Loss=0.8125]
Validate: 100%|███| 281/281 [00:06<00:00, 43.63it/s, Accuracy=0.708, Loss=-.708]
Epoch #4: 100%|███████████████████| 656/656 [04:40<00:00,  2.34it/s, Loss=0.875]
Validate: 100%|███| 281/281 [00:06<00:00, 44.10it/s, Accuracy=0.707, Loss=-.707]
Epoch #5: 100%|█████████████████| 656/656 [04:40<00:00,  2.34it/s, Loss=0.84375]
Validate: 100%|███| 281/281 [00:06<00:00, 43.53it/s, Accuracy=0.708, Loss=-.708]
Epoch #6: 100%|█████████████

In [8]:
elite_params = evolution(SNN, params, mnist_dl)

ParameterReshaper: 9064 parameters detected for optimization.


Epoch #0: 100%|████████████████| 656/656 [02:28<00:00,  4.42it/s, Loss=0.640625]
Validate: 100%|███| 281/281 [00:06<00:00, 41.09it/s, Accuracy=0.673, Loss=-.673]
Epoch #1: 100%|█████████████████| 656/656 [02:22<00:00,  4.59it/s, Loss=0.78125]
Validate: 100%|███| 281/281 [00:03<00:00, 91.85it/s, Accuracy=0.698, Loss=-.698]
Epoch #2: 100%|████████████████████| 656/656 [02:19<00:00,  4.69it/s, Loss=0.75]
Validate: 100%|███████| 281/281 [00:03<00:00, 91.79it/s, Accuracy=0.7, Loss=-.7]
Epoch #3: 100%|████████████████████| 656/656 [02:23<00:00,  4.58it/s, Loss=0.75]
Validate: 100%|███| 281/281 [00:05<00:00, 47.13it/s, Accuracy=0.699, Loss=-.699]
Epoch #4: 100%|████████████████████| 656/656 [02:23<00:00,  4.58it/s, Loss=0.75]
Validate: 100%|███████| 281/281 [00:05<00:00, 47.04it/s, Accuracy=0.7, Loss=-.7]
Epoch #5: 100%|█████████████████| 656/656 [02:22<00:00,  4.60it/s, Loss=0.78125]
Validate: 100%|███| 281/281 [00:05<00:00, 46.91it/s, Accuracy=0.697, Loss=-.697]
Epoch #6: 100%|█████████████

In [9]:
def plot_readout(data, tgt_label):
    plt.imshow(data.T, aspect="auto")
    plt.title("Readout Activations")
    plt.xlabel("Time")
    plt.ylabel("Class")
    plt.title("Class Label: {}".format(tgt_label))
    plt.colorbar()
    plt.yticks(range(data.shape[-1]))
    plt.show()
    

In [10]:
mnist_dl.train_reset()
sample = mnist_dl.train_step()

In [11]:
sample.obs[0].shape

(64, 28, 28, 1)

In [12]:
elite_params

{'conv2_d': {'w': Array([[[ -0.12038945,   2.347349  ,   4.190474  ,   0.09780852,
            -3.7764838 ,  -2.2755766 ,  -0.43088043,   0.71603346],
          [  1.348577  ,  10.1645975 ,  -1.2470248 ,  -1.9695814 ,
             0.98805493,   6.39167   ,   1.1576335 ,   4.3169594 ],
          [ -0.04378413,   2.1139872 ,  -4.3617873 ,  -0.05585323,
            -4.0712934 ,   3.565277  ,   1.0488793 ,   6.50354   ]],
  
         [[ -0.5222165 ,  -1.2852201 ,   5.2829056 ,  -4.9608846 ,
            -4.401705  ,   3.7529924 ,  -6.7730412 ,  -1.6442848 ],
          [ -1.3024524 ,   0.9589633 ,   1.4926426 ,  -1.4066026 ,
             5.9444876 ,   6.014889  ,   4.5523415 ,   7.4128704 ],
          [ -1.7314181 , -10.909041  ,   1.7575699 ,   1.1773806 ,
            -4.0330834 ,  -3.0933728 ,  -0.4524294 ,  -5.7398677 ]],
  
         [[  4.0371323 ,   0.23270608,  -2.8066769 ,  -9.832846  ,
             1.0465441 ,  -5.4976954 ,  -0.72520375,  -2.6119177 ],
          [  2.7469723 ,   0.86

In [13]:
readout, spk_cts = SNN.apply(elite_params, reg_init, sample.obs)

ValueError: 'conv2_d/w' with retrieved shape (3, 3, 8) does not match shape=(3, 3, 1, 8) dtype=dtype('float16')

In [None]:
spk_cts

In [None]:
s, v = readout

In [None]:
s

In [None]:
jnp.sum(s) # the network is almost entirely silent!!!!!!

In [None]:
plot_readout(s[0], sample.labels[0])

In [14]:
def test(SNN, in_params, dl):

    @jax.jit
    def net_test(individual, events, targets):
        readout, spike_counts = SNN.apply(individual, reg_init, events)
        traces, V = readout
        acc, pred = spyx.loss.integral_accuracy(traces, targets)
        xe_loss = spyx.loss.integral_crossentropy(traces, targets)
        reg_loss = spyx.loss.mse_spike_count_reg(spike_counts, 16)
        loss = xe_loss + reg_loss
        return acc, loss
    
    dl.test_reset()
    accs = []
    #preds = []
    losses = []
    
    pbar = tqdm([*range(dl.test_len//dl.batch_size)])
    pbar.set_description("Validating")
    for _ in pbar:
        events, targets = dl.test_step()
        
        acc, loss = net_test(in_params, events, targets)
        
        accs.append(acc)
        #preds.append(pred)
        losses.append(loss)
        
        pbar.set_postfix(Loss=np.mean(losses), Accuracy=np.mean(accs))
    
    return accs, losses

In [15]:
acc, losses = test(SNN, elite_params, mnist_dl)

Validating:   0%|                                       | 0/156 [00:00<?, ?it/s]


ValueError: 'conv2_d/w' with retrieved shape (3, 3, 8) does not match shape=(3, 3, 1, 8) dtype=dtype('float16')