In [None]:
import brian2 as b2
import matplotlib.pyplot as plt
import brian2tools as b2t
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d

from functions import *

# Synapses

In [None]:
b2.start_scope()

# Fix the random state so the random initial voltages and the white-noise term
# are reproducible; Brian's seed() also seeds numpy's global generator.
SEED = 42
b2.seed(SEED)

N = 100                 # number of neurons
# Voltages are plain numbers (the model declares `v : 1`); plots label them mV.
v_L = -67               # leak reversal value
v_I = -80               # synaptic reversal value (below v_res, so coupling is inhibitory)
v_thr = -40             # spike threshold
v_res = -70             # reset value after a spike
duration = 1000*b2.ms   # simulated time
sigma = 0.1             # noise amplitude
I_dc = 3.4/b2.ms        # constant drive
g_L = 0.1/b2.ms         # leak rate
tau = 0.1*b2.ms         # time scale used only to scale the noise term
tau_D = 20*b2.ms        # time constant of the synaptic R -> g filter


In [None]:
# Noisy leaky integrate-and-fire population. v is dimensionless; g is a synaptic
# conductance (hertz) driven by R through a two-stage filter with time constant
# tau_D; the xi term is white noise scaled by sigma and tau.
eqs = '''
    dv/dt = g_L*(v_L-v) + I_dc + g*(v_I - v) + sigma*xi*tau**-0.5 : 1
    dg/dt = (R-g)/tau_D : hertz
    dR/dt = -R/tau_D : hertz
'''

G = b2.NeuronGroup(
    N,
    eqs,
    threshold='v>v_thr',
    reset='v=v_res',
)

# Start each neuron at a random voltage between reset and threshold. The drawn
# values are kept in `initial_v` and are exactly what the group starts from.
initial_v = np.random.uniform(low=v_res, high=v_thr, size=N)
G.v = initial_v
G.g = 100/b2.second


In [None]:
# All-to-all recurrent coupling without self-connections: every presynaptic
# spike increments the target neuron's R by the synaptic weight w.
S2 = b2.Synapses(
    G,
    G,
    model='w : hertz # synaptic weight',
    on_pre='R+=w',
)
S2.connect(condition='i!=j')
S2.w = 4 * b2.hertz


In [None]:

spike_mon = b2.SpikeMonitor(G)
monitor = b2.StateMonitor(G, variables=True, record=True)
# Record only a few synapses: recording all N*(N-1) synapses at every time step
# stores N*(N-1)*duration/dt values per variable. Moreover, w never changes during
# the run (on_pre only increments R), so a small sample is enough.
mon_syn = b2.StateMonitor(S2, variables=True, record=np.arange(10))
rate_mon = b2.PopulationRateMonitor(G)

# b2.collect() gathers every Brian object visible in this namespace (G, S2 and
# all four monitors above), so nothing has to be added to the network by hand.
net = b2.Network(b2.collect())
net.run(duration)


In [None]:
# Voltage trace of selected neurons; the labels are shown in the legend.
fig, ax = plt.subplots(figsize=(18, 4))
for neuron_id in range(4, 5):
    ax.plot(monitor.t/b2.ms, monitor.v[neuron_id], label=f'Neuron {neuron_id}')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('v')
ax.legend();


In [None]:
# Spike raster of the whole population (trailing ';' hides the return-value repr).
b2t.brian_plot(spike_mon);

In [None]:

class VoltageMap:
    """Voltage-density and per-neuron voltage heatmaps for one recorded population.

    Parameters
    ----------
    voltages : numpy.ndarray
        Recorded voltages indexed as ``voltages[neuron, time_step]``.
    spikes : dict
        Maps a neuron index to the array of that neuron's spike times (e.g. the
        result of ``SpikeMonitor.values('t')``). Units are dropped with
        ``np.asarray``, so Brian time quantities become plain floats in seconds.
    duration : float or Brian Quantity
        Length of the recording. A Brian quantity is converted to ms for the time
        axis; a plain number is assumed to already be in ms.
    v_res, v_thr : float
        Lower and upper edge of the voltage range that is histogrammed.
    v_resolution : int, default 500
        Number of equal-width voltage bins between ``v_res`` and ``v_thr``.
    """

    def __init__(self, voltages, spikes, duration, v_res, v_thr, v_resolution=500):
        self.voltages = voltages
        self.v_res = v_res
        self.v_thr = v_thr
        self.v_resolution = v_resolution
        self.duration = duration
        self.time_steps = voltages.shape[1]
        self.N_neurons = voltages.shape[0]
        # One histogram column per time step; filled lazily by get_density().
        self.density = np.zeros((v_resolution, self.time_steps))
        self.density_computed = False
        # np.asarray drops Brian units so pandas receives plain float arrays.
        # The stacked Series is indexed by (neuron, spike number).
        spike_times = {neuron: np.asarray(times) for neuron, times in spikes.items()}
        self.spikes = pd.DataFrame.from_dict(spike_times, orient='index').stack()
        # Difference between consecutive spike times of the same neuron
        # (NaN for each neuron's first spike).
        self.interspike_times = self.spikes.groupby(level=0).diff()

    def get_density(self):
        """Fill ``self.density`` with a voltage histogram for every time step.

        Row ``k`` counts the neurons whose voltage lies in the ``k``-th of
        ``v_resolution`` equal bins spanning ``[v_res, v_thr]``; voltages outside
        that range are not counted.
        """
        for t in range(self.time_steps):
            self.density[:, t] = np.histogram(
                self.voltages[:, t],
                range=(self.v_res, self.v_thr),
                bins=self.v_resolution,
            )[0]
        self.density_computed = True

    def _time_axis_ms(self):
        """Return the sample times in ms.

        Assumes ``time_steps`` equally spaced samples starting at t = 0, as a
        monitor recording once per time step produces -- TODO confirm if the
        recording dt differs from the simulation dt.
        """
        if hasattr(self.duration, 'dim'):
            # Brian Quantity: express it in ms so the values match the axis label.
            duration_ms = float(self.duration / b2.ms)
        else:
            duration_ms = float(self.duration)
        # endpoint=False: the last sample is one step before `duration`.
        return np.linspace(0, duration_ms, self.time_steps, endpoint=False)

    def _voltage_bin_centers(self):
        """Return the centres of the ``v_resolution`` voltage bins used by get_density()."""
        edges = np.linspace(self.v_res, self.v_thr, self.v_resolution + 1)
        return (edges[:-1] + edges[1:]) / 2

    def plot_density(self):
        """Plot neuron counts per voltage bin over time; returns the QuadMesh."""
        if not self.density_computed:
            self.get_density()

        fig, ax = plt.subplots(figsize=(20, 8))
        ax.set_xlabel('Time (ms)')
        ax.set_ylabel('Voltage (mV)')
        v_map = ax.pcolormesh(
            self._time_axis_ms(),
            self._voltage_bin_centers(),
            self.density,
            shading='auto',
        )
        fig.colorbar(v_map, ax=ax, label='Neurons per bin')
        return v_map

    def plot_rasterplot(self):
        """Plot every neuron's voltage over time as a heatmap; returns the QuadMesh.

        Uses the raw voltages only, so the voltage density is not computed here.
        """
        fig, ax = plt.subplots(figsize=(20, 8))
        ax.set_xlabel('Time (ms)')
        ax.set_ylabel('Neuron Cell')
        v_map = ax.pcolormesh(
            self._time_axis_ms(),
            np.arange(self.N_neurons),
            self.voltages,
            shading='auto',
        )
        fig.colorbar(v_map, ax=ax, label='Voltage (mV)')
        return v_map


In [None]:
# Copy the recorded voltages (neurons x time steps) and build the density map
# with 200 voltage bins between reset and threshold.
volts = np.array(monitor.v)
v_map = VoltageMap(
    volts,
    spikes=spike_mon.values('t'),
    duration=duration,
    v_res=v_res,
    v_thr=v_thr,
    v_resolution=200,
)
v_map.plot_density()

In [None]:
# Per-neuron voltage heatmap (trailing ';' hides the QuadMesh repr).
v_map.plot_rasterplot();

In [None]:
# Peek at the long-format spike table, indexed by (neuron, spike number),
# instead of dumping every spike.
v_map.spikes.head()

In [None]:
# Gaussian-smoothed population rate. The width is the Gaussian's standard
# deviation; it must be much shorter than the 1 s run, so it is given in ms.
rate_mon.smooth_rate(window='gaussian', width=25*b2.ms)

In [None]:
# Population rate as a plain-float table: rate in Hz, time in seconds.
# Dividing by the unit strips Brian's Quantity wrapper before the data reaches pandas.
global_rate = pd.DataFrame(
    {
        'rate': rate_mon.rate / b2.hertz,
        'time': rate_mon.t / b2.second,
    },
)

In [None]:
# Inspect the time column (seconds).
global_rate.loc[:, 'time']

In [None]:
# Raw monitor timestamps wrapped in a Series, for comparison with global_rate['time'].
pd.Series(data=rate_mon.t)

In [None]:
# Second and third time samples (positional rows 1 and 2).
global_rate.iloc[1:3, global_rate.columns.get_loc('time')]

In [None]:
# Smooth the raw rate with a Gaussian kernel. sigma is measured in samples, not
# time: 25 samples correspond to 2.5 ms at Brian's default dt of 0.1 ms.
global_rate['rate_smooth'] = gaussian_filter1d(global_rate['rate'], sigma=25)

In [None]:
# Re-inspect the time column after adding the smoothed rate.
global_rate.time

In [None]:
# Local maxima of the smoothed population rate; their times become the peak table.
peak_indices = find_peaks(global_rate['rate_smooth'].to_numpy())[0]
global_spikes = global_rate['time'].iloc[peak_indices].to_frame(name='peak_time')

In [None]:
# Pad the detected peak times with the start (0 s) and end of the simulation, so
# the midpoint computation that follows has a neighbour on both sides of every peak.
# The table is rebuilt from peak_indices (not from global_spikes itself) so that
# re-running this cell does not pad it twice.
global_spikes = pd.concat(
    [
        pd.DataFrame({'peak_time': [0.0]}),
        global_rate['time'].iloc[peak_indices].to_frame(name='peak_time'),
        pd.DataFrame({'peak_time': [float(duration / b2.second)]}),
    ],
    ignore_index=True,
)

In [None]:
# Midpoint between each peak and the following one (NaN for the last row);
# these midpoints serve as boundaries between oscillation cycles.
global_spikes['next_cycle'] = (
    global_spikes['peak_time']
    .rolling(window=2)
    .mean()
    .shift(periods=-1)
)

In [None]:
# Full peak table (small: one row per detected peak plus the two padding rows).
global_spikes.loc[:]

In [None]:
# Boolean mask of spikes inside the window around the first detected rate peak,
# i.e. between the first two cycle boundaries (inclusive on both ends).
v_map.spikes.between(
    left=global_spikes['next_cycle'].iloc[0],
    right=global_spikes['next_cycle'].iloc[1],
)

In [None]:
# Raw population rate (divided by 10, presumably only so it shares an axis with
# the smoothed trace), the smoothed rate, and the detected peaks, all against time.
# Peaks come from peak_indices: global_spikes also holds the artificial 0 / end rows
# and has no rate column.
fig, ax = plt.subplots(figsize=(18, 4))
ax.plot(global_rate['time'], global_rate['rate'] / 10, label='rate / 10')
ax.plot(global_rate['time'], global_rate['rate_smooth'], label='smoothed rate')
ax.scatter(
    global_rate['time'].iloc[peak_indices],
    global_rate['rate_smooth'].iloc[peak_indices],
    color='red',
    label='detected peaks',
)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Rate (Hz)')
ax.legend();