# Heterogeneous Synapses

This notebook demonstrates how to _efficiently_ apply a distinct synapse to each dimension of a vector. The main use case is when each neuron in an ensemble has a different synaptic model.

In [None]:
import numpy as np

import nengo
from nengo.utils.filter_design import cont2discrete, tf2ss


class HeteroSynapse(object):
    """Callable class for applying different synapses to a vector.

    If `elementwise == True`, `len(synapses)` must match `size_in`, in
        which case each synapse is applied separately to each dimension,
        and so `size_out == size_in`.

    If `elementwise == False` (default), each synapse is applied to every
        dimension, and so `size_out == size_in * len(synapses)`. The
        output dimensions are ordered by input dimension, such that
        index `i*len(synapses) + j` is the `i`'th input dimension convolved
        with the `j`'th filter.

    The latter can be used to connect to a population of neurons with a
    different synapse for each neuron.

    Every synapse is discretized at time-step `dt` and all of them are
    stacked into one block-diagonal discrete state-space system, so each
    call performs a fixed number of matrix multiplies no matter how many
    synapses there are.

    The filter state is allocated on the first call and never reset, so an
    instance should not be reused across separate simulator runs.

    Parameters
    ----------
    synapses : Synapse or iterable of Synapse
        A single synapse or a non-empty collection of synapses. Each must
        expose `num` and `den` (transfer-function coefficients).
    dt : float
        Time-step used to discretize the synapses.
    elementwise : bool, optional
        Whether synapse `k` is applied only to input dimension `k`
        (True) or every synapse to every dimension (False, default).
    method : str, optional
        Discretization method passed to `cont2discrete` (default "zoh").

    Raises
    ------
    ValueError
        If `synapses` is empty, or (when called) if the input vector has
        the wrong shape or changes size between calls.
    """

    def __init__(self, synapses, dt, elementwise=False, method="zoh"):
        from scipy.linalg import block_diag

        if isinstance(synapses, nengo.synapses.Synapse):
            synapses = [synapses]
        # Materialize so generators work and the stored list is reusable.
        synapses = list(synapses)
        if not synapses:
            raise ValueError("synapses must contain at least one synapse")
        self.synapses = synapses
        self.dt = dt
        self.elementwise = elementwise
        self.method = method

        As, Bs, Cs, Ds = [], [], [], []
        for synapse in synapses:
            A, B, C, D, _ = cont2discrete(
                tf2ss(synapse.num, synapse.den), dt, method=method)
            As.append(A)
            Bs.append(B)
            Cs.append(C)
            Ds.append(D)

        # One block-diagonal system holding every synapse's state.
        self.A = block_diag(*As)
        # Elementwise: synapse k is driven only by input row k (block-diagonal
        # B/D). Otherwise every synapse is driven by the same single input row
        # (stacked B/D), and the input dimensions become the columns of u.
        self.B = block_diag(*Bs) if elementwise else np.vstack(Bs)
        self.C = block_diag(*Cs)
        self.D = block_diag(*Ds) if elementwise else np.vstack(Ds)

        self._x = None  # set by __call__ once size of u is known

    def __call__(self, t, u):
        """Advance all filters by one step and return the filtered output.

        Parameters
        ----------
        t : float
            Current simulation time (unused).
        u : array_like, 1-D
            Input vector. In elementwise mode its length must equal the
            number of synapses.

        Returns
        -------
        ndarray, 1-D
            The filtered outputs, ordered as described in the class
            docstring.

        Raises
        ------
        ValueError
            If `u` is not 1-D, has the wrong length in elementwise mode, or
            changes length between calls.
        """
        u = np.asarray(u)
        if u.ndim != 1:
            raise ValueError(
                "expected a 1-D input vector, got shape %s" % (u.shape,))
        if self.elementwise:
            n_expected = self.B.shape[1]
            if u.shape[0] != n_expected:
                raise ValueError(
                    "elementwise HeteroSynapse expects %d inputs (one per "
                    "synapse), got %d" % (n_expected, u.shape[0]))
            u = u[:, None]  # column: row k feeds synapse k
        else:
            u = u[None, :]  # row: each column is one input dimension
        if self._x is None:
            self._x = np.zeros((len(self.A), u.shape[1]))
        elif self._x.shape[1] != u.shape[1]:
            # Without this check numpy would silently broadcast the state.
            raise ValueError(
                "input size changed from %d to %d between calls"
                % (self._x.shape[1], u.shape[1]))
        y = np.dot(self.C, self._x) + np.dot(self.D, u)
        self._x = np.dot(self.A, self._x) + np.dot(self.B, u)
        return self.to_vector(y)

    @classmethod
    def to_vector(cls, y):
        """Flatten `y` in column-major (Fortran) order.

        For a `(n_synapses, size_in)` output this places input `i` through
        synapse `j` at index `i*n_synapses + j`.
        """
        return y.flatten(order='F')

## Vector Space Example

This example applies 3 different synapses to each dimension of a 2-D vector, producing a 6-D output. It does so efficiently, using only a single linear time-invariant system. It also demonstrates the `elementwise` flag, which applies a distinct synapse to each dimension.

In [None]:
n_neurons = 20
dt = 0.0005
T = 0.1
dims_in = 2
synapses = [nengo.Alpha(0.1), nengo.Lowpass(0.005), nengo.Alpha(0.02)]
n_syn = len(synapses)

# Every input dimension is filtered by every synapse.
dims_out = dims_in * n_syn
encoders = nengo.dists.UniformHypersphere(surface=True).sample(
    n_neurons, dims_out)
ens_seed = 1


with nengo.Network() as model:
    # Input stimulus: an independent white signal per input dimension,
    # merged into a single dims_in-dimensional node.
    stim_array = [nengo.Node(output=nengo.processes.WhiteSignal(T))
                  for _ in range(dims_in)]
    stim = nengo.Node(size_in=dims_in)
    for dim, source in enumerate(stim_array):
        nengo.Connection(source, stim[dim], synapse=None)

    # HeteroSynapse nodes. For the elementwise variant the synapse list is
    # tiled, so output k is filtered by synapses[k % n_syn].
    syn_dot = nengo.Node(
        size_in=dims_in, output=HeteroSynapse(synapses, dt))
    syn_elemwise = nengo.Node(
        size_in=dims_out,
        output=HeteroSynapse(synapses * dims_in, dt, elementwise=True))

    # Three identical ensembles (same seed and encoders) for comparison.
    x = []
    for _ in range(3):
        x.append(nengo.Ensemble(
            n_neurons, dims_out, seed=ens_seed, encoders=encoders))
    x_exp, x_dot, x_elemwise = x

    # Expected: dimension i*n_syn + j receives input i through synapse j.
    for j, synapse in enumerate(synapses):
        nengo.Connection(stim, x_exp[j::n_syn], synapse=synapse)

    # Actual (method #1 = matrix multiplies): the node output already uses
    # the i*n_syn + j ordering.
    nengo.Connection(stim, syn_dot, synapse=None)
    nengo.Connection(syn_dot, x_dot, synapse=None)

    # Actual (method #2 = elementwise): copy input i into slots i*n_syn + j
    # for every j, so that each copy is filtered by synapse j.
    for j in range(n_syn):
        nengo.Connection(stim, syn_elemwise[j::n_syn], synapse=None)
    nengo.Connection(syn_elemwise, x_elemwise, synapse=None)

    # Probes
    p_exp = nengo.Probe(x_exp, synapse=None)
    p_act_dot = nengo.Probe(x_dot, synapse=None)
    p_act_elemwise = nengo.Probe(x_elemwise, synapse=None)


# Check correctness
sim = nengo.Simulator(model, dt=dt)
sim.run(T)

for probe in (p_act_dot, p_act_elemwise):
    assert np.allclose(sim.data[probe], sim.data[p_exp])

## Neuron Space Example

This example applies 100 different synapses to 100 different neurons (one synapse per neuron). It is an order of magnitude more efficient than making each of those connections separately.

This can also be generalized to representing a vector, but care must be taken when setting the function on the connection from the `HeteroSynapse` node to the `Ensemble`, so that the encoders are properly embedded.

In [None]:
n_neurons = 100
dt = 0.001
T = 0.1

# One lowpass synapse per neuron, time constants drawn from U(0.001, 0.1).
taus = nengo.dists.Uniform(0.001, 0.1).sample(n_neurons)
synapses = [nengo.Lowpass(tau) for tau in taus]
encoders = nengo.dists.UniformHypersphere(surface=True).sample(n_neurons, 1)
# Per-neuron gain: each neuron's scalar encoder, shape (n_neurons,).
# Column indexing (rather than np.squeeze) stays 1-D even if n_neurons == 1.
transform = encoders[:, 0]
ens_seed = 1


with nengo.Network() as model:
    # Input stimulus
    stim = nengo.Node(output=nengo.processes.WhiteSignal(T))

    # HeteroSynapse node: the 1-D input is filtered by each of the
    # n_neurons synapses, giving one filtered copy per neuron.
    syn = nengo.Node(size_in=1, output=HeteroSynapse(synapses, dt))

    # For comparing results
    x = [nengo.Ensemble(n_neurons, 1, seed=ens_seed, encoders=encoders)
         for _ in range(2)]  # expected, actual

    # Expected: one connection per neuron, each with its own synapse and a
    # transform that routes the encoded stimulus to that neuron only.
    for i, synapse in enumerate(synapses):
        neuron_transform = np.zeros_like(encoders)
        neuron_transform[i, :] = encoders[i, :]
        nengo.Connection(stim, x[0].neurons, transform=neuron_transform,
                         synapse=synapse)

    # Actual: the HeteroSynapse output is already filtered per neuron, so
    # only the per-neuron encoder gain remains to be applied.
    nengo.Connection(stim, syn, synapse=None)
    nengo.Connection(syn, x[1].neurons, function=lambda v: transform * v,
                     synapse=None)

    # Probes
    p_exp = nengo.Probe(x[0].neurons, synapse=None)
    p_act = nengo.Probe(x[1].neurons, synapse=None)


# Check correctness
sim = nengo.Simulator(model, dt=dt)
sim.run(T)

assert np.allclose(sim.data[p_act], sim.data[p_exp])