In [1]:
%matplotlib inline

In [2]:
import matplotlib.pyplot as plt
import numpy as np

In [3]:
from rsnn.rsnn.generator import NetworkGenerator
from rsnn.rsnn.network import Network
from rsnn.spike_train.generator import PeriodicSpikeTrainGenerator
from rsnn.spike_train.periodic_spike_train import MultiChannelPeriodicSpikeTrain, PeriodicSpikeTrain

In [4]:
# Apply the 'paper' plotting style to every figure below.
# NOTE(review): 'paper' does not look like a built-in matplotlib style; presumably a
# custom paper.mplstyle has to be installed for this call to succeed — confirm.
plt.style.use('paper')

# Create Neural Network

In [5]:
# Neuron settings, passed positionally to NetworkGenerator below; the two refractory
# values are also passed to PeriodicSpikeTrainGenerator.
num_neurons = 100
nominal_threshold = 1.0
absolute_refractory = 5.0
relative_refractory = 5.0

# Synapse settings, also passed to NetworkGenerator.
# NOTE(review): num_synapses is presumably the number of synapses per neuron (the
# weight plots repeat each neuron index num_synapses times) — confirm.
num_synapses = 500
synapse_beta = 5.0
synapse_delay_lim = (1.0, 60.0)

# Spike-train settings: firing_rate goes to PeriodicSpikeTrainGenerator, period to its rand().
period = 100.0
firing_rate = 0.1

## Neural Network

In [6]:
# Build the generator from the configuration cell (arguments are positional, so
# their order matters) and obtain one network from it via rand().
network_generator = NetworkGenerator(
    num_neurons,
    nominal_threshold,
    absolute_refractory,
    relative_refractory,
    num_synapses,
    synapse_beta,
    synapse_delay_lim,
)
network = network_generator.rand()

In [7]:
# Pick one neuron (position 7 in network.neurons) to inspect in the cells below.
neuron = network.neurons[7]

In [8]:
# Two stacked panels on a shared time axis: the response of the inspected neuron's
# first synapse (top) and of its adaptive threshold (bottom).
times = np.linspace(0, 100, 1000)

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(3.5, 4))
ax_synapse, ax_threshold = axes

ax_synapse.plot(times, neuron.synapses[0].response(times), c="C2")
ax_synapse.set_ylabel(r"$h(t)$")

ax_threshold.plot(times, neuron.adaptive_threshold.response(times), c="C2")
ax_threshold.set_ylabel(r"$\eta(t)$")
ax_threshold.set_xlabel(r"$t$")
# The x axis is shared, so this limit applies to both panels.
ax_threshold.set_xlim(0, 60)

fig.tight_layout()
#fig.savefig('spike_responses.pgf')
plt.show()

In [9]:
# Source-neuron index of every synapse in the network, one entry per synapse.
sources = np.array([
    synapse.source.idx
    for post_neuron in network.neurons
    for synapse in post_neuron.synapses
])

# Distinct source indices and how many synapses originate from each.
unique, counts = np.unique(sources, return_counts=True)

In [10]:
# adjacency_matrix[i, j] counts the synapses going from source neuron i to the
# neuron j that owns them.
adjacency_matrix = np.zeros((num_neurons, num_neurons))
edges = (
    (synapse.source.idx, post_neuron.idx)
    for post_neuron in network.neurons
    for synapse in post_neuron.synapses
)
for pre_idx, post_idx in edges:
    adjacency_matrix[pre_idx, post_idx] += 1

In [11]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5))

# adjacency_matrix is indexed [source, target]. imshow draws the first index (rows)
# along y and the second (columns) along x, so x is the post-synaptic neuron and
# y the pre-synaptic one (the labels used to be swapped).
im = ax.imshow(adjacency_matrix, cmap="Greys")

ax.set_xlabel("post-synaptic neuron")
ax.set_ylabel("pre-synaptic neuron")
ax.grid(visible=False)

plt.show()

In [12]:
# Out degree of a neuron = number of synapses that have it as source. `sources`
# holds one source index per synapse, so counting occurrences per index gives the
# out degrees; minlength keeps neurons that never appear as a source (degree 0).
# Previously the raw source indices were histogrammed, which does not match the
# "out degree" / "count" axis labels.
out_degrees = np.bincount(sources, minlength=num_neurons)

fig, ax = plt.subplots(1, 1, figsize=(5, 2))
ax.hist(out_degrees, color="C2", bins=50)

ax.set_xlabel("out degree")
ax.set_xlim(out_degrees.min(), out_degrees.max())
ax.set_ylabel("count")

plt.show()

## Spike Trains

In [13]:
# Build the spike-train generator from the firing rate and the two refractory settings,
# then ask it for `num_neurons` spike trains with the configured `period`.
# NOTE(review): rand() is presumably random and no seed is set in this notebook — confirm.
spike_train_generator = PeriodicSpikeTrainGenerator(firing_rate, absolute_refractory, relative_refractory)
spike_trains = spike_train_generator.rand(period, num_neurons)

In [14]:
# Spike train stored at the inspected neuron's index.
spike_train = spike_trains.spike_trains[neuron.idx]

In [15]:
# Stem plot of the inspected spike train: one marker of fixed height 0.6 per firing time.
stem_times = spike_train.firing_times
stem_heights = np.full_like(stem_times, 0.6)

fig, ax = plt.subplots(1, 1, figsize=(5, 0.8))
ax.stem(
    stem_times,
    stem_heights,
    basefmt=" ",
    markerfmt="C2o",
    linefmt="C2-",
    label=r"$x_7(.)$",
)

ax.set_xlim(0, period)
ax.set_ylim(0, 1.4)
ax.set_xlabel(r"$t$")
#ax.legend()

# Strip everything except the time axis: no grid, no y ticks, no top/right/left frame.
ax.grid(visible=False)
ax.set_yticks([])
for side in ('top', 'right', 'left'):
    ax.spines[side].set_visible(False)

plt.show()

In [16]:
# Number of spikes in each channel, then the distinct counts and how often each occurs.
num_spikes = np.array([channel.num_spikes for channel in spike_trains.spike_trains])
unique, counts = np.unique(num_spikes, return_counts=True)

In [17]:
# Bar chart: how many channels have each spike count.
fig, ax = plt.subplots(1, 1, figsize=(5, 2))
ax.bar(unique, counts, color="C2")
ax.set(xlabel="num spikes", ylabel="count")

plt.show()

## Memorization

In [18]:
# Two copies taken before any memorize call; they are memorized below with the
# "l1" and "l2" options, while `network` itself is memorized without that argument.
network_l1 = network.copy()
network_l2 = network.copy()

In [19]:
# Arguments forwarded to the memorize calls below.
# NOTE(review): eps / gap / slope are presumably the time tolerance around spikes, the
# margin below threshold and the minimum slope at firing — confirm against rsnn.
eps, gap, slope = 1.0, 1.0, 0.5
# Lower and upper value passed as the weight limits.
weights_lim = (-0.1, 0.1)

In [20]:
# Memorize the spike trains without passing the optional sixth argument.
network.memorize(spike_trains, weights_lim, eps, gap, slope)

In [21]:
# Same call on the first copy, with "l1" as the sixth argument.
network_l1.memorize(spike_trains, weights_lim, eps, gap, slope, "l1")

In [22]:
# Same call on the second copy, with "l2" as the sixth argument.
network_l2.memorize(spike_trains, weights_lim, eps, gap, slope, "l2")

In [23]:
def _synapse_weights(net):
    """Return all synapse weights of `net` as one flat array, neuron by neuron."""
    return np.array([synapse.weight for owner in net.neurons for synapse in owner.synapses])


def _status_colors(net):
    """Return one colour per synapse: "C3" if the owning neuron's problem status is "optimal", else "C1"."""
    per_neuron = ["C3" if owner.prob.status == "optimal" else "C1" for owner in net.neurons]
    return np.repeat(per_neuron, num_synapses, axis=0)


# Each neuron index repeated num_synapses times, to pair with the flat weight arrays.
neurons_indices = np.repeat(np.arange(num_neurons), num_synapses, axis=0)

weights = _synapse_weights(network)
weights_l1 = _synapse_weights(network_l1)
weights_l2 = _synapse_weights(network_l2)

colors = _status_colors(network)
colors_l1 = _status_colors(network_l1)
colors_l2 = _status_colors(network_l2)

In [24]:
# One panel per memorization variant. Each panel is coloured with the solver status
# of its own network (the l1/l2 panels previously reused the colours of `network`,
# leaving colors_l1 / colors_l2 unused).
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(5, 2), sharey=True)

panels = [
    (weights, colors, "no reg"),
    # Raw strings: "\e" is not a valid Python escape sequence.
    (weights_l1, colors_l1, r"$\ell_1$ reg"),
    (weights_l2, colors_l2, r"$\ell_2$ reg"),
]
for panel_ax, (panel_weights, panel_colors, panel_title) in zip(axes, panels):
    panel_ax.scatter(neurons_indices, panel_weights, color=panel_colors, s=1, alpha=0.1)
    panel_ax.set_xlabel("postsynaptic neuron")
    panel_ax.set_title(panel_title)

# The y axis is shared, so only the left-most panel carries the label.
axes[0].set_ylabel("synaptic weights")

fig.tight_layout()
plt.show()

In [25]:
# save networks
# One pickle file per memorization variant, written under the relative "networks/" folder.
# NOTE(review): presumably the "networks/" directory must already exist — confirm
# how save_to_file handles a missing directory.
network.save_to_file("networks/network_no_reg.pkl")
network_l1.save_to_file("networks/network_l1_reg.pkl")
network_l2.save_to_file("networks/network_l2_reg.pkl")

In [26]:
# load networks
# network = Network(num_neurons, nominal_threshold, absolute_refractory, relative_refractory)
# network.load_from_file("networks/network_no_reg.pkl")
# 
# network_l1 = Network(num_neurons, nominal_threshold, absolute_refractory, relative_refractory)
# network_l1.load_from_file("networks/network_l1_reg.pkl")
# 
# network_l2 = Network(num_neurons, nominal_threshold, absolute_refractory, relative_refractory)
# network_l2.load_from_file("networks/network_l2_reg.pkl")

In [27]:
# NOTE(review): unconditional exception — presumably a deliberate stop so that "Run All"
# halts here. The cells below reference names (e.g. num_inputs, cp, rng, wmin, wmax,
# firing_threshold) that are not defined in the visible part of this notebook.
raise ValueError()

In [None]:
# Give every neuron its own PeriodicSpikeTrain, built from the firing times stored
# at the neuron's index in `spike_trains`.
for n_ in network.neurons:
    channel_firing_times = spike_trains.spike_trains[n_.idx].firing_times
    n_.spike_train = PeriodicSpikeTrain(period, channel_firing_times)

In [None]:
# Initialise the inspected neuron's template with the same eps / gap / slope values
# used for the network-level memorize calls.
neuron.init_template_single(eps, gap, slope)

In [None]:
# Run memorization on the single inspected neuron with the shared weight limits.
neuron.memorize(weights_lim)

In [None]:
# Display the firing times of the spike train attached to the inspected neuron.
neuron.spike_train.firing_times

In [None]:
# Use the names this notebook defines (spike_trains, weights_lim, eps, gap, slope),
# mirroring the earlier network_l1.memorize call. The previous arguments
# (multi_channel_spike_train, SYNAPSE_WEIGHTS_LIM, EPS, GAP, SLOPE) are not defined
# in the visible notebook and raise NameError on a fresh kernel.
# NOTE(review): this re-memorizes `network` itself with the "l1" option — confirm that is intended.
network.memorize(spike_trains, weights_lim, eps, gap, slope, "l1")

In [None]:
# 10000 evenly spaced sample times covering one period (both end points included).
times = np.linspace(0, period, 10000)

In [None]:
def input_spike_resp(t, beta=5.0):
    """Alpha-shaped input response ``(t / beta) * exp(1 - t / beta)``, zero for ``t < 0``.

    The kernel is 0 at ``t == 0`` and reaches its maximum value 1 at ``t == beta``.

    Args:
        t: a scalar time or an ``np.ndarray`` of times.
        beta: time constant of the kernel. Defaults to 5.0, the value that used to be
            hard-coded, so existing one-argument calls behave as before.

    Returns:
        A scalar for scalar ``t``; a newly allocated array for array ``t``.
    """
    if isinstance(t, np.ndarray):
        z = t / beta * np.exp(1 - t / beta)
        z[t < 0] = 0.0  # no response before the spike
        return z

    if t < 0:
        return 0.0

    return t / beta * np.exp(1 - t / beta)
    
def input_spike_resp_deriv(t, beta=5.0):
    """Time derivative of the alpha kernel ``(t / beta) * exp(1 - t / beta)``.

    For ``t >= 0`` this is ``(1 - t / beta) / beta * exp(1 - t / beta)``; for ``t < 0``
    it is 0. The previous implementation differentiated a double-exponential kernel
    (using ``math``, ``soma_decay`` and ``synapse_decay``), which no longer matches
    the alpha kernel that ``input_spike_resp`` evaluates.

    Args:
        t: a scalar time or an ``np.ndarray`` of times.
        beta: time constant of the kernel; 5.0 matches the constant used by
            ``input_spike_resp``.

    Returns:
        A scalar for scalar ``t``; a newly allocated array for array ``t``.
    """
    if isinstance(t, np.ndarray):
        z = (1 - t / beta) / beta * np.exp(1 - t / beta)
        z[t < 0] = 0.0  # the kernel is identically zero before the spike
        return z

    if t < 0:
        return 0.0

    return (1 - t / beta) / beta * np.exp(1 - t / beta)

def refractory_spike_resp(t):
    """Refractory response at time ``t`` after a spike.

    Returns 0 for ``t <= 0``, ``-inf`` for ``0 < t <= abs_refractory_period`` and
    ``-refractory_weight * exp(-(t - abs_refractory_period) / soma_decay)`` afterwards.
    Reads the notebook-level names ``refractory_weight``, ``abs_refractory_period``
    and ``soma_decay``. Accepts a scalar or an ``np.ndarray``.
    """
    if not isinstance(t, np.ndarray):
        # Scalar input: test the three regimes in order.
        if t <= 0:
            return 0.0
        if t <= abs_refractory_period:
            return -np.inf
        return -refractory_weight * np.exp(-(t - abs_refractory_period) / soma_decay)

    # Array input: evaluate the decaying tail everywhere, then overwrite the other
    # two regimes (the t <= 0 mask is applied last so it wins over the -inf mask).
    resp = -refractory_weight * np.exp(-(t - abs_refractory_period) / soma_decay)
    resp[t <= abs_refractory_period] = -np.inf
    resp[t <= 0] = 0.0
    return resp

def refractory_spike_resp_deriv(t):
    """Derivative of the decaying part of the refractory response.

    Returns 0 for ``t <= abs_refractory_period`` and
    ``refractory_weight / soma_decay * exp(-(t - abs_refractory_period) / soma_decay)``
    afterwards. Accepts a scalar or an ``np.ndarray``.
    """
    if not isinstance(t, np.ndarray):
        if t <= abs_refractory_period:
            return 0.0
        return refractory_weight / soma_decay * np.exp(-(t - abs_refractory_period) / soma_decay)

    slope = refractory_weight / soma_decay * np.exp(-(t - abs_refractory_period) / soma_decay)
    slope[t <= abs_refractory_period] = 0.0
    return slope

In [None]:
def potential(t, delays, weights, input_spike_trains, modulo=True):
    """Weighted sum of input responses at time(s) ``t``.

    For every input ``k`` in ``range(num_inputs)`` the response ``input_spike_resp`` is
    evaluated at ``t - delays[k] - firing_time`` for each firing time of that input and
    summed; the per-input sums are then combined as ``weights @ sums``.

    Args:
        t: scalar time or 1-D ``np.ndarray`` of times.
        delays: per-input delays, indexed by ``k``.
        weights: per-input weights.
        input_spike_trains: object whose ``spike_trains[k].firing_times`` is an array.
        modulo: when true, time differences are wrapped with ``% period``.
    """
    batched = isinstance(t, np.ndarray)

    per_input = []
    for k in range(num_inputs):
        firing_times = input_spike_trains.spike_trains[k].firing_times
        if batched:
            # Rows index firing times, columns index sample times.
            lag = t[None, :] - delays[k] - firing_times[:, None]
        else:
            lag = t - delays[k] - firing_times
        if modulo:
            lag = lag % period
        per_input.append(np.sum(input_spike_resp(lag), axis=0))

    return weights @ np.array(per_input)

def potential_deriv(t, delays, weights, input_spike_trains, modulo=True):
    """Weighted sum of input-response derivatives at time(s) ``t``.

    Same structure as ``potential`` but built from ``input_spike_resp_deriv``. The
    scalar code path previously called ``input_spike_resp`` and therefore returned
    the potential itself instead of its derivative; scalar and array inputs now use
    the same kernel.

    Args:
        t: scalar time or 1-D ``np.ndarray`` of times.
        delays: per-input delays, indexed by ``k`` in ``range(num_inputs)``.
        weights: per-input weights.
        input_spike_trains: object whose ``spike_trains[k].firing_times`` is an array.
        modulo: when true, time differences are wrapped with ``% period``.
    """
    batched = isinstance(t, np.ndarray)

    per_input = []
    for k in range(num_inputs):
        firing_times = input_spike_trains.spike_trains[k].firing_times
        if batched:
            # Rows index firing times, columns index sample times.
            lag = t[None, :] - delays[k] - firing_times[:, None]
        else:
            lag = t - delays[k] - firing_times
        if modulo:
            lag = lag % period
        per_input.append(np.sum(input_spike_resp_deriv(lag), axis=0))

    return weights @ np.array(per_input)

def threshold(t, target_spike_train, modulo=True):
    """Firing threshold at time(s) ``t``.

    Computed as ``firing_threshold`` minus the sum of ``refractory_spike_resp`` over
    the time elapsed since each firing time of ``target_spike_train``. With
    ``modulo`` true the elapsed times are wrapped with ``% period``.
    Accepts a scalar or a 1-D ``np.ndarray`` for ``t``.
    """
    firing_times = target_spike_train.firing_times
    if isinstance(t, np.ndarray):
        # Rows index firing times, columns index sample times.
        elapsed = t[None, :] - firing_times[:, None]
    else:
        elapsed = t - firing_times

    if modulo:
        elapsed = elapsed % period

    return firing_threshold - np.sum(refractory_spike_resp(elapsed), axis=0)

In [None]:
# Scratch evaluation of 1 + exp(-x / 5) with x = (0 - 53.55974238 - 5.0) mod 100.
# NOTE(review): presumably a hand check of the threshold at t = 0 for a spike at
# 53.56 with a refractory period of 5 — confirm, or delete the cell.
1.0 + np.exp(-((0 - 53.55974238 - 5.0) % 100)/5.0)

In [None]:
#def threshold(t, target_spike_train, modulo=True):
#    if isinstance(t, np.ndarray):
#        if modulo: 
#            return firing_threshold - np.sum(refractory_spike_resp((t[None,:] - target_spike_train.firing_times[:,None])%period), axis=0)
#
#        return firing_threshold - np.sum(refractory_spike_resp((t[None,:] - target_spike_train.firing_times[:,None])), axis=0)
#    
#    if modulo: 
#        return firing_threshold - np.sum(refractory_spike_resp((t - target_spike_train.firing_times[:,None])%period), axis=0)        
#    return firing_threshold - np.sum(refractory_spike_resp((t - target_spike_train.firing_times[:,None])), axis=0)

In [None]:
#def threshold_deriv(t, target_spike_train):
#    if isinstance(t, np.ndarray):
#        return - np.sum(refractory_spike_resp_deriv((t[None,:] - target_spike_train.firing_times[:,None])), axis=0)
#    return - np.sum(refractory_spike_resp_deriv((t - target_spike_train.firing_times[:,None])), axis=0)

In [None]:
# Scratch evaluation of the input response at (0 - 92.17851908 - 1) mod 100.
# NOTE(review): presumably a hand check for a spike at 92.18 with delay 1 seen at
# t = 0 — confirm, or delete the cell.
input_spike_resp((0 - 92.17851908 - 1)%100)

In [None]:
# Two stacked panels on a shared time axis: the input response (top) and the
# negated refractory response (bottom). The figure is written to disk, not shown.
times = np.linspace(0, 100, 1000)

fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(3.5, 4))
ax_input, ax_refractory = axes

ax_input.plot(times, input_spike_resp(times), c="C2", label=r"$g_0 * g_k$")
ax_input.set_ylabel(r"$h(t)$")

ax_refractory.plot(times, -refractory_spike_resp(times), c="C2", label=r"$-g_0$")
ax_refractory.set_ylabel(r"$\eta(t)$")
ax_refractory.set_xlabel(r"$t$")
# The x axis is shared, so this limit applies to both panels.
ax_refractory.set_xlim(0, 100)

fig.tight_layout()
fig.savefig('spike_responses.pgf')
#plt.show()

In [None]:
# Rebuild the spike-train generator for the single-neuron experiment below.
# NOTE(review): abs_refractory_period / rel_refractory_period are not defined in the
# visible notebook (the configuration cell uses absolute_refractory /
# relative_refractory) — confirm which names are intended.
spike_train_generator = PeriodicSpikeTrainGenerator(firing_rate, abs_refractory_period, rel_refractory_period)

In [None]:
# Multi-channel input spike trains (num_inputs channels) over one period.
# NOTE(review): num_inputs is not defined in the visible notebook — confirm.
input_spike_trains = spike_train_generator.rand(period, num_inputs)
# Disabled alternative: feed the same single spike train to every input channel.
#input_spike_train = spike_train_generator.rand(period)
#input_spike_trains = MultiChannelPeriodicSpikeTrain(
#    period, 
#    num_inputs,
#    [input_spike_train.firing_times]*num_inputs
#)
# Single-channel spike train the neuron should reproduce.
target_spike_train = spike_train_generator.rand(period)

In [None]:
# Number of spikes in each input channel.
num_spikes = np.array([channel.num_spikes for channel in input_spike_trains.spike_trains])

In [None]:
# Stem plot of the target spike train: one unit-height marker per firing time.
target_times = target_spike_train.firing_times

fig, ax = plt.subplots(figsize=(5, 0.8))
ax.stem(
    target_times,
    np.ones_like(target_times),
    basefmt=" ",
    markerfmt="C2o",
    linefmt="C2-",
)
ax.set_xlim(0, period)
ax.set_ylim(0, 2)
ax.set_yticks([])
plt.show()

In [None]:
# Distinct spike counts among the input channels and how many channels have each.
unique, counts = np.unique(num_spikes, return_counts=True)

fig, ax = plt.subplots(figsize=(5, 2))
ax.bar(x=unique, height=counts, color="C2")
#ax.set_xticks(np.arange(np.min(num_spikes), np.max(num_spikes)+1))
plt.show()

In [None]:
# One delay per input, drawn uniformly from [1.0, 60.0) — the same bounds as
# synapse_delay_lim in the configuration cell.
# NOTE(review): `rng` is not defined in the visible notebook; presumably a
# numpy Generator such as np.random.default_rng(seed) — confirm.
delays = rng.uniform(1.0, 60.0, num_inputs)

# Computing Weights

In [None]:
def compute_spike_train(weights):
    """Extend a copy of the target spike train by simulating threshold crossings.

    Starting one absolute refractory period after the last target spike and scanning
    up to ``2 * period`` in unit-length windows, a spike is appended wherever
    ``potential - threshold`` changes sign inside the window (located with ``brentq``).
    The threshold is evaluated without modulo on the growing spike train, so every
    appended spike affects the search that follows.

    Args:
        weights: per-input weights forwarded to ``potential``.

    Returns:
        The copied spike train including the appended spikes.
    """
    spike_train = target_spike_train.copy()

    def margin(t_):
        # Positive when the potential exceeds the current threshold.
        return potential(t_, delays, weights, input_spike_trains) - threshold(t_, spike_train, modulo=False)

    left = spike_train.firing_times[-1] + abs_refractory_period
    while left < 2 * period:
        right = left + 1
        sign_change = margin(left) * margin(right) <= 0
        if not sign_change:
            # No crossing in this window (a NaN product also lands here).
            left += 1
            continue

        crossing = brentq(margin, left, right)
        spike_train.append(crossing)
        left += abs_refractory_period

    return spike_train

In [None]:
def _input_features(kernel, sample_times):
    """Summed `kernel` responses of every input at `sample_times`.

    Column k holds, for each sample time, the sum over input k's firing times of
    ``kernel((t - delays[k] - firing_time) % period)``.
    """
    columns = []
    for k in range(num_inputs):
        lag = sample_times[:, None] - delays[k] - input_spike_trains.spike_trains[k].firing_times[None, :]
        columns.append(np.sum(kernel(lag % period), axis=1))
    return np.stack(columns, axis=1)


def _refractory_sum(sample_times):
    """Sum of refractory responses of the target spike train at `sample_times` (periodic)."""
    elapsed = sample_times[:, None] - target_spike_train.firing_times[None, :]
    return np.sum(refractory_spike_resp(elapsed % period), axis=1)


# Target firing times repeated one period earlier and later, so that distances to
# the nearest spike are also found across the period boundary.
extended_firing_times = np.concatenate([
    target_spike_train.firing_times - period,
    target_spike_train.firing_times,
    target_spike_train.firing_times + period
])

# --- Slope samples: fine grid points within 1.0 of a target spike; required slope 0.5.
times = np.arange(0, period, 1e-2)
time_diff = times[:, None] - extended_firing_times[None, :]
mask_slope = np.any(np.abs(time_diff) < 1.0, axis=1)
t_slope = times[mask_slope]
y_slope = _input_features(input_spike_resp_deriv, t_slope)
z_slope = np.full_like(t_slope, 0.5)

# --- Level samples just before firing: fine grid points less than 1.0 before a target spike.
# (Same grid and time differences as above.)
mask_level_f = np.any((time_diff > -1.0) & (time_diff < 0.0), axis=1)
t_level_f = times[mask_level_f]
y_level_f = _input_features(input_spike_resp, t_level_f)
z_level_f = firing_threshold - _refractory_sum(t_level_f)

# --- Level samples away from firing: coarser grid points that are, for every target
# spike, either more than abs_refractory_period after it or more than 1.0 before it.
times = np.arange(0, period, 1e-1)
time_diff = times[:, None] - extended_firing_times[None, :]
mask_level = np.all((time_diff > abs_refractory_period) | (time_diff < -1.0), axis=1)
t_level = times[mask_level]
y_level = _input_features(input_spike_resp, t_level)
z_level = 0.0 - _refractory_sum(t_level)

# --- Firing samples: the target firing times themselves.
y_firing = _input_features(input_spike_resp, target_spike_train.firing_times)
z_firing = firing_threshold - _refractory_sum(target_spike_train.firing_times)

## Bounded Weights

In [None]:
# Feasibility problem (constant objective): find weights inside [wmin, wmax] that
# satisfy the firing, level and slope constraints built above.
weights = cp.Variable(num_inputs)
objective = cp.Minimize(0.0)

template_constraints = [
    y_firing @ weights == z_firing,
    y_level_f @ weights <= z_level_f,
    y_level @ weights <= z_level,
    y_slope @ weights >= z_slope,
]
box_constraints = [weights >= wmin, weights <= wmax]
constraints = template_constraints + box_constraints

prob = cp.Problem(objective, constraints)

In [None]:
# Solve with the "GUROBI" solver and report the resulting problem status.
# NOTE(review): presumably requires a working Gurobi installation/licence — confirm.
_ = prob.solve(solver="GUROBI")
print("problem is", prob.status)

In [None]:
times = np.linspace(0, period, 1000)

fig = plt.figure(layout="constrained")
subplots = fig.subfigures()
axes_template = subplots.subplots(nrows=2, ncols=1)
ax_weights, ax_potential = axes_template

# Top panel: the solved weights in ascending order.
ax_weights.scatter(np.arange(num_inputs), np.sort(weights.value), s=2, c="C2", label=r"$w$")
ax_weights.set_xlim(0, num_inputs)
ax_weights.set_ylim(wmin, wmax)
ax_weights.set_xlabel(r"$k$")
ax_weights.set_ylabel(r"$w_k$")

# Bottom panel: target firing times (dashed), potential and threshold over one period.
for firing_time in target_spike_train.firing_times:
    ax_potential.axvline(firing_time, linestyle="--", c="C0")
ax_potential.plot(times, potential(times, delays, weights.value, input_spike_trains), c="C2", label=r"$w$")
ax_potential.plot(times, threshold(times, target_spike_train), c="C1")
ax_potential.set_xlim(0, period)
ax_potential.set_xlabel(r"$t$")
_ = ax_potential.set_ylabel(r"$z(t)$")  # assignment keeps the cell from echoing the label object

In [None]:
# Simulate with the unregularized weights, keep only the spikes after the first
# period, and compare them with the target spike train.
spike_train = compute_spike_train(weights.value)
after_first_period = spike_train.firing_times > period
spike_train_ = SpikeTrain(spike_train.firing_times[after_first_period])
corr, lag = single_channel_correlation(target_spike_train, spike_train_, eps=1e-3)

# Similarity, fraction of exactly-zero weights, and mean squared weight.
print(f"spike train similarity: {corr:.3f} (with a lag of {lag:.3f})" )
print(f"sparsity: {1 - np.count_nonzero(weights.value)/num_inputs:.6f}")
print(f"energy: {np.sum(np.square(weights.value))/num_inputs:.6f}")

## Bounded Weights with L2 Regularization

In [None]:
# Same constraints as the bounded problem, minimizing the 2-norm of the weights.
weights_r2 = cp.Variable(num_inputs)
objective = cp.Minimize(cp.norm2(weights_r2))

template_constraints = [
    y_firing @ weights_r2 == z_firing,
    y_level_f @ weights_r2 <= z_level_f,
    y_level @ weights_r2 <= z_level,
    y_slope @ weights_r2 >= z_slope,
]
box_constraints = [weights_r2 >= wmin, weights_r2 <= wmax]
constraints = template_constraints + box_constraints

prob = cp.Problem(objective, constraints)

In [None]:
# Solve the 2-norm problem with the "GUROBI" solver and report the problem status.
# NOTE(review): presumably requires a working Gurobi installation/licence — confirm.
_ = prob.solve(solver="GUROBI")
print("problem is", prob.status)

In [None]:
times = np.linspace(0, period, 1000)

fig = plt.figure(layout="constrained")
subplots = fig.subfigures()
axes_template = subplots.subplots(nrows=2, ncols=1)
ax_weights, ax_potential = axes_template

# Top panel: the 2-norm-regularized weights in ascending order.
ax_weights.scatter(np.arange(num_inputs), np.sort(weights_r2.value), s=2, c="C2", label=r"$w$")
ax_weights.set_xlim(0, num_inputs)
ax_weights.set_ylim(wmin, wmax)
ax_weights.set_xlabel(r"$k$")
ax_weights.set_ylabel(r"$w_k$")

# Bottom panel: target firing times (dashed), potential and threshold over one period.
for firing_time in target_spike_train.firing_times:
    ax_potential.axvline(firing_time, linestyle="--", c="C0")
ax_potential.plot(times, potential(times, delays, weights_r2.value, input_spike_trains), c="C2", label=r"$w$")
ax_potential.plot(times, threshold(times, target_spike_train), c="C1")
ax_potential.set_xlim(0, period)
ax_potential.set_xlabel(r"$t$")
_ = ax_potential.set_ylabel(r"$z(t)$")  # assignment keeps the cell from echoing the label object

In [None]:
# Evaluate the 2-norm-regularized solution. This cell previously used `weights`
# (the unregularized solution) in all three places, so it repeated the numbers of
# the "Bounded Weights" section instead of reporting on weights_r2.
spike_train = compute_spike_train(weights_r2.value)
# Keep only the spikes after the first period before comparing with the target.
spike_train_ = SpikeTrain(spike_train.firing_times[spike_train.firing_times > period])
corr, lag = single_channel_correlation(target_spike_train, spike_train_, eps=1e-3)

# Similarity, fraction of exactly-zero weights, and mean squared weight.
print(f"spike train similarity: {corr:.3f} (with a lag of {lag:.3f})" )
print(f"sparsity: {1 - np.count_nonzero(weights_r2.value)/num_inputs:.6f}")
print(f"energy: {np.sum(np.square(weights_r2.value))/num_inputs:.6f}")

## Bounded Weights with L1 Regularization

In [None]:
# Same constraints as the bounded problem, minimizing the 1-norm of the weights.
weights_r1 = cp.Variable(num_inputs)
objective = cp.Minimize(cp.norm1(weights_r1))

template_constraints = [
    y_firing @ weights_r1 == z_firing,
    y_level_f @ weights_r1 <= z_level_f,
    y_level @ weights_r1 <= z_level,
    y_slope @ weights_r1 >= z_slope,
]
box_constraints = [weights_r1 >= wmin, weights_r1 <= wmax]
constraints = template_constraints + box_constraints

prob = cp.Problem(objective, constraints)

In [None]:
# Solve the 1-norm problem with the "GUROBI" solver and report the problem status.
# NOTE(review): presumably requires a working Gurobi installation/licence — confirm.
_ = prob.solve(solver="GUROBI")
print("problem is", prob.status)

In [None]:
times = np.linspace(0, period, 1000)

fig = plt.figure(layout="constrained")
subplots = fig.subfigures()
axes_template = subplots.subplots(nrows=2, ncols=1)
ax_weights, ax_potential = axes_template

# Top panel: the 1-norm-regularized weights in ascending order.
ax_weights.scatter(np.arange(num_inputs), np.sort(weights_r1.value), s=2, c="C2", label=r"$w$")
ax_weights.set_xlim(0, num_inputs)
ax_weights.set_ylim(wmin, wmax)
ax_weights.set_xlabel(r"$k$")
ax_weights.set_ylabel(r"$w_k$")

# Bottom panel: target firing times (dashed), potential and threshold over one period.
for firing_time in target_spike_train.firing_times:
    ax_potential.axvline(firing_time, linestyle="--", c="C0")
ax_potential.plot(times, potential(times, delays, weights_r1.value, input_spike_trains), c="C2", label=r"$w$")
ax_potential.plot(times, threshold(times, target_spike_train), c="C1")
ax_potential.set_xlim(0, period)
ax_potential.set_xlabel(r"$t$")
_ = ax_potential.set_ylabel(r"$z(t)$")  # assignment keeps the cell from echoing the label object

In [None]:
# Simulate with the 1-norm-regularized weights, keep only the spikes after the
# first period, and compare them with the target spike train.
spike_train = compute_spike_train(weights_r1.value)
after_first_period = spike_train.firing_times > period
spike_train_ = SpikeTrain(spike_train.firing_times[after_first_period])
corr, lag = single_channel_correlation(target_spike_train, spike_train_, eps=1e-3)

# Similarity, fraction of exactly-zero weights, and mean squared weight.
print(f"spike train similarity: {corr:.3f} (with a lag of {lag:.3f})" )
print(f"sparsity: {1 - np.count_nonzero(weights_r1.value)/num_inputs:.6f}")
print(f"energy: {np.sum(np.square(weights_r1.value))/num_inputs:.6f}")

# Comparison

In [None]:
times = np.linspace(0, period, 1000)

# Three stacked panels sharing the x axis, one per solution, in solver order
# (weights are not sorted here).
fig = plt.figure(layout="constrained", figsize=(3.4, 5))
subplots = fig.subfigures()
axes_template = subplots.subplots(nrows=3, ncols=1, sharex=True)

panels = [
    (weights.value, "C1", "w/o reg"),
    (weights_r2.value, "C2", r"$\ell_2$ reg"),
    (weights_r1.value, "C3", r"$\ell_1$ reg"),
]
for panel_ax, (panel_weights, panel_color, panel_label) in zip(axes_template, panels):
    panel_ax.scatter(np.arange(num_inputs), panel_weights, s=2, c=panel_color, label=panel_label)
    panel_ax.set_ylim(wmin, wmax)
    panel_ax.set_ylabel(r"$w_k$")

# Shared x axis: limits and label are set once on the bottom panel.
axes_template[2].set_xlim(0, num_inputs)
axes_template[2].set_xlabel(r"$k$")

plt.show()

#plt.figlegend(loc='center left', bbox_to_anchor=(1, 0.5))

#fig.savefig('weights_reg.pgf')