In [2]:
import sys
sys.path.insert(1, '../')
import torch

from spiral import (
    IntegrateAndFireSoma,
    LeakyMembrane,
    LinearDendrite,
    Axon,
    ConvergentSynapticPlasticity,
    Synapse,
    DisconnectorSynapticCover,
    RandomConnectivity,
    RandomFixedCouplingConnectivity,
    RandomFixedPresynapticPartnersConnectivity,
    AutapseConnectivity,
    NotConnectivity,
    AndConnectivity,
    ScalingResponseFunction,
)
# Neuron class: LeakyMembrane applied to IntegrateAndFireSoma
# (presumably a leaky integrate-and-fire model — confirm against spiral docs).
LIF = LeakyMembrane(IntegrateAndFireSoma)


def new_n(shape, name='neuronsA'):
    """Build an analyzable LIF neuron group with this notebook's shared settings.

    Args:
        shape: Shape of the neuron population, e.g. ``(3,)``. Required.
        name: Name of the population; defaults to ``'neuronsA'``.

    Returns:
        A new ``LIF`` instance created with ``batch=1``, ``dt=1.`` and
        ``analyzable=True``.
    """
    return LIF(
        name=name,
        shape=shape,
        batch=1,
        dt=1.,
        analyzable=True,
    )


def new_axon():
    """Build an analyzable Axon using ``ScalingResponseFunction(scale=100.)``."""
    return Axon(
        analyzable=True,
        response_function=ScalingResponseFunction(scale=100.),
    )


def new_dendrite():
    """Build an analyzable LinearDendrite with ``ConvergentSynapticPlasticity``."""
    return LinearDendrite(
        analyzable=True,
        plasticity_model=ConvergentSynapticPlasticity(),
    )

In [28]:
# Boolean mask of the even entries of the 2x3 integer tensor [[1, 2, 3], [4, 5, 6]].
a = torch.arange(1, 7).reshape(2, 3).remainder(2).eq(0)
a

tensor([[False,  True, False],
        [ True, False,  True]])

In [32]:
# Insert singleton axes around a's rows -> (1, 2, 1, 3), then tile 4x on axes 0 and 2 -> (4, 2, 4, 3).
b = a[None, :, None, :].repeat(4, 1, 4, 1)

In [41]:
# Copy `a` into the "diagonal" blocks c[i, :, i, :]; every other entry stays zero/False.
c = torch.zeros_like(b)
# Index a block explicitly: `i` is not defined until the loop below, so referencing it
# here only worked through a variable leaked from a previous run.
print(c[0, :, 0, :].shape)
# Axes 0 and 2 have the same length here (both tiled 4x when building `b`).
for i in range(c.shape[0]):
    c[i, :, i, :] = a
c.shape

torch.Size([2, 3])


torch.Size([4, 2, 4, 3])

In [53]:
# Reduce over axes 1 and 2, leaving one count per (block index, column).
c.sum(dim=(1, 2))

tensor([[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]])

In [None]:
from spiral.analysis import FunctionGenerator

# 250-step input for 3 neurons. The baseline dict maps steps 0, 50, 100, 150, 200
# to 0, 150, 0, 150, 0 (presumably a piecewise-constant baseline keyed by start step).
I = FunctionGenerator.generate(
    250,
    shape=(3,),
    baseline={50 * k: (150 if k % 2 else 0) for k in range(5)},
)
# Scale column 1 by 1.2 so that neuron receives a 20% stronger current.
I[:, 1] *= 1.2

In [None]:
from spiral.analysis import Simulator
import matplotlib.pyplot as plt
from matplotlib_dashboard import MatplotlibDashboard

def simulate_and_plot(n, a, I, title='', times=200):
    """Reset and simulate neuron group `n` driven by `I`, then plot three panels.

    Panels, top to bottom: 'N' (neurotransmitter of axon `a`, with `title`),
    'S' (spikes of `n`), and 'I' (each input column faintly, plus their mean).

    Args:
        n: Neuron group; must provide ``progress``, ``reset`` and ``plot_spikes``.
        a: Axon attached to ``n``; must provide ``plot__neurotransmitter``.
        I: Input fed to the simulator as ``'direct_input'``. Assumed 2-D
            (time, neuron), since its mean over axis 1 is plotted — TODO confirm.
        title: Title of the top panel.
        times: Number of simulation steps. Defaults to 200 for backward
            compatibility; note the input built with
            ``FunctionGenerator.generate(250, ...)`` spans 250 steps, so pass
            ``times=len(I)`` to simulate all of it.
    """
    simulator = Simulator(n.progress)
    n.reset()
    simulator.simulate(inputs={'direct_input': I}, times=times)

    plt.figure(figsize=(14, 5))
    md = MatplotlibDashboard([
        ['N'],
        ['S'],
        ['I'],
    ], hspace=.5, wspace=.3)

    n.plot_spikes(md['S'])
    # NOTE(review): the double underscore in `plot__neurotransmitter` looks like a
    # typo; confirm the method name against spiral's Axon API.
    a.plot__neurotransmitter(md['N'])

    # Individual input columns faintly, their mean in solid blue.
    md['I'].plot(I, color='blue', alpha=.2)
    md['I'].plot(I.mean(axis=1), color='blue')
    md['I'].set_ylabel('Current')

    md['N'].set_title(title)
    plt.show()

In [None]:
# `new_n` requires a shape argument; calling it with none raises TypeError.
# Size the population from the input's trailing dimensions (assumes I is
# (time, neuron), as generated with shape=(3,) — TODO confirm).
n = new_n(tuple(I.shape[1:]))
# NOTE(review): this rebinds `a`, shadowing the boolean tensor `a` from the
# earlier exploration cells; re-running those cells afterwards will fail.
a = Axon(
    terminal=[2],
    analyzable=True,
)
n.use(a)
simulate_and_plot(n, a, I, 'Without Delay, Default response function')