Skip to content

Commit

Permalink
ENH: save net.spiketimes as list of list
Browse files Browse the repository at this point in the history
  • Loading branch information
jasmainak committed Jun 19, 2019
1 parent 022a351 commit 14fa808
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
3 changes: 2 additions & 1 deletion examples/plot_simulate_evoked.py
Expand Up @@ -35,14 +35,15 @@
# Now let's simulate the dipole
# You can simulate multiple trials in parallel by using n_jobs > 1
net = Network(params)
dpls = simulate_dipole(net, n_jobs=3, n_trials=3)
dpls = simulate_dipole(net, n_jobs=1, n_trials=2)

###############################################################################
# and then plot it
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 1, sharex=True, figsize=(6, 6))
for dpl in dpls:
dpl.plot(ax=axes[0])
net.plot_input(ax=axes[1])

###############################################################################
# Finally, we can also plot the spikes.
Expand Down
18 changes: 9 additions & 9 deletions mne_neuron/dipole.py
Expand Up @@ -18,13 +18,15 @@ def _hammfilt(x, winsz):

def _clone_and_simulate(params, trial_idx):
from .network import Network

if trial_idx != 0:
params['prng_*'] = trial_idx

net = Network(params, n_jobs=1)

from neuron import h
# create sources and init
net._create_all_src()
net.state_init()
# parallel network connector
net._parnet_connect()
# set to record spikes
net.spiketimes = h.Vector()
Expand All @@ -44,9 +46,6 @@ def _simulate_single_trial(net):
if rank == 0:
print("running on %d cores" % nhosts)

# if trial_idx != 0:
# params['prng_*'] = trial_idx

# global variables, should be node-independent
h("dp_total_L2 = 0.")
h("dp_total_L5 = 0.")
Expand Down Expand Up @@ -117,7 +116,7 @@ def prsimtime():
dpl.convert_fAm_to_nAm()
dpl.scale(net.params['dipole_scalefctr'])
dpl.smooth(net.params['dipole_smooth_win'] / h.dt)
return dpl
return dpl, net.spiketimes.to_python(), net.spikegids.to_python()


def simulate_dipole(net, n_trials=1, n_jobs=1):
Expand All @@ -139,8 +138,10 @@ def simulate_dipole(net, n_trials=1, n_jobs=1):
The dipole object or list of dipole objects if n_trials > 1
"""
parallel, myfunc = _parallel_func(_clone_and_simulate, n_jobs=n_jobs)
dpl = parallel(myfunc(net.params, idx) for idx in range(n_trials))

out = parallel(myfunc(net.params, idx) for idx in range(n_trials))
dpl, spiketimes, spikegids = zip(*out)
net.spiketimes = spiketimes
net.spikegids = spikegids
return dpl

# TODO: add crop method to dipole
Expand Down Expand Up @@ -173,7 +174,6 @@ def __init__(self, times, data): # noqa: D102
self.t = times
self.dpl = {'agg': data[:, 0], 'L2': data[:, 1], 'L5': data[:, 2]}

# conversion from fAm to nAm
def convert_fAm_to_nAm(self):
""" must be run after baseline_renormalization()
"""
Expand Down
17 changes: 13 additions & 4 deletions mne_neuron/network.py
Expand Up @@ -28,11 +28,20 @@ class Network(object):
----------
cells : list of Cell objects.
The list of cells
gid_dict : dict
Dictionary with keys 'evprox1', 'evdist1' etc.
containing the range of Cell IDs of different cell types.
ext_list : dictionary of list of ExtFeed.
Keys are:
'evprox1', 'evprox2', etc.
'evdist1', etc.
'extgauss', 'extpois'
spiketimes : tuple (n_trials, ) of list of float
Each element of the tuple is a trial.
The list contains the time stamps of spikes.
spikegids : tuple (n_trials, ) of list of float
Each element of the tuple is a trial.
The list contains the cell IDs of neurons that spiked.
"""

def __init__(self, params, n_jobs=1):
Expand Down Expand Up @@ -422,8 +431,8 @@ def plot_input(self, ax=None, show=True):
The matplotlib figure handle.
"""
import matplotlib.pyplot as plt
spikes = np.array(self.spiketimes.to_python())
gids = np.array(self.spikegids.to_python())
spikes = np.array(sum(self.spiketimes, []))
gids = np.array(sum(self.spikegids, []))
valid_gids = np.r_[[v for (k, v) in self.gid_dict.items()
if k.startswith('evprox')]]
mask_evprox = np.in1d(gids, valid_gids)
Expand Down Expand Up @@ -458,8 +467,8 @@ def plot_spikes(self, ax=None, show=True):
The matplotlib figure object
"""
import matplotlib.pyplot as plt
spikes = np.array(self.spiketimes.to_python())
gids = np.array(self.spikegids.to_python())
spikes = np.array(sum(self.spiketimes, []))
gids = np.array(sum(self.spikegids, []))
spike_times = np.zeros((4, spikes.shape[0]))
cell_types = ['L5_pyramidal', 'L5_basket', 'L2_pyramidal', 'L2_basket']
for idx, key in enumerate(cell_types):
Expand Down

0 comments on commit 14fa808

Please sign in to comment.