forked from jonescompneurolab/hnn-core
/
dipole.py
281 lines (232 loc) · 8.5 KB
/
dipole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
"""Class to handle the dipoles."""
# Authors: Mainak Jas <mainak.jas@telecom-paristech.fr>
# Sam Neymotin <samnemo@gmail.com>
import warnings
import numpy as np
from numpy import convolve, hamming
from .viz import plot_dipole
def _hammfilt(x, winsz):
    """Smooth a 1-D signal with a unit-sum Hamming window.

    Parameters
    ----------
    x : array
        The signal to smooth.
    winsz : int
        Length of the Hamming window, in samples.

    Returns
    -------
    smoothed : array
        ``x`` convolved with the normalized window using ``mode='same'``.
    """
    window = hamming(winsz)
    # normalize the weights to sum to one so a constant signal keeps its
    # level away from the edges
    kernel = window / sum(window)
    return convolve(x, kernel, mode='same')
def simulate_dipole(net, n_trials=None, record_vsoma=False):
    """Simulate a dipole given the experiment parameters.

    Parameters
    ----------
    net : Network object
        The Network object specifying how cells are
        connected.
    n_trials : int | None
        The number of trials to simulate. If None the value in
        net.params['N_trials'] will be used
    record_vsoma : bool
        Option to record somatic voltages from cells

    Returns
    -------
    dpls: list
        List of dipole objects for each trial

    Raises
    ------
    ValueError
        If the number of trials is smaller than 1.
    TypeError
        If ``record_vsoma`` is not a bool.

    Notes
    -----
    All arguments are validated before ``net`` is modified, so an invalid
    call leaves ``net.params`` and the network feeds untouched.
    """
    # Validate everything first: previously N_trials was overwritten and the
    # feeds re-instantiated before the checks ran, leaving ``net``
    # half-updated when an error was raised.
    trials_requested = n_trials is not None
    if not trials_requested:
        n_trials = net.params['N_trials']
    if n_trials < 1:
        raise ValueError("Invalid number of simulations: %d" % n_trials)
    if not isinstance(record_vsoma, bool):
        raise TypeError("record_vsoma must be bool, got %s"
                        % type(record_vsoma).__name__)

    from .parallel_backends import _BACKEND, JoblibBackend

    if trials_requested:
        net.params['N_trials'] = n_trials
        # need to redo these if n_trials changed after net.__init__()!
        net._instantiate_feeds(n_trials=n_trials)
    net.params['record_vsoma'] = record_vsoma

    backend = _BACKEND
    if backend is None:
        # no global backend set: fall back to a single-job JoblibBackend
        backend = JoblibBackend(n_jobs=1)
    return backend.simulate(net)
def read_dipole(fname, units='nAm'):
    """Read dipole values from a file and create a Dipole instance.

    Parameters
    ----------
    fname : str
        Full path to the input file (.txt). Column 0 holds the times and
        columns 1-3 the 'agg', 'L2' and 'L5' dipoles (the layout written by
        ``Dipole.write``).
    units : str
        Units of the dipole values in the file, either 'nAm' (default) or
        'fAm'. Assigned to ``dpl.units``.

    Returns
    -------
    dpl : Dipole
        The instance of Dipole class

    Raises
    ------
    ValueError
        If ``units`` is not 'nAm' or 'fAm'.
    """
    # previously any unrecognised value was silently ignored and the dipole
    # was labelled 'fAm'
    if units not in ('nAm', 'fAm'):
        raise ValueError("units must be 'nAm' or 'fAm', got %r" % (units,))
    # ndmin=2 keeps a single-sample file 2-D so the column slicing works
    dpl_data = np.loadtxt(fname, dtype=float, ndmin=2)
    dpl = Dipole(dpl_data[:, 0], dpl_data[:, 1:4])
    dpl.units = units
    return dpl
def average_dipoles(dpls):
    """Compute dipole averages over a list of Dipole objects.

    Parameters
    ----------
    dpls: list of Dipole objects
        Contains list of dipole objects, each with a `data` member containing
        'L2', 'L5' and 'agg' components

    Returns
    -------
    dpl: instance of Dipole
        A new dipole object with each component of `dpl.data` representing the
        average over the same components in the input list. Its `times` come
        from the first dipole, its `units` match the inputs and its `nave`
        is ``len(dpls)``.

    Raises
    ------
    ValueError
        If fewer than two dipoles are given, if any dipole is already an
        average (``nave > 1``), or if the dipoles differ in units or in
        number of time samples.
    """
    # an average needs at least two trials; dpls[0] also supplies the times
    if len(dpls) < 2:
        raise ValueError("Need at least two dipole object to compute an"
                         " average")

    for dpl_idx, dpl in enumerate(dpls):
        if dpl.nave > 1:
            raise ValueError("Dipole at index %d was already an average of %d"
                             " trials. Cannot reaverage" %
                             (dpl_idx, dpl.nave))

    units = dpls[0].units
    n_times = len(dpls[0].times)
    for dpl_idx, dpl in enumerate(dpls):
        if dpl.units != units:
            raise ValueError("Dipole at index %d is in %s, expected %s"
                             % (dpl_idx, dpl.units, units))
        if len(dpl.times) != n_times:
            raise ValueError("Dipole at index %d has %d time samples, "
                             "expected %d" % (dpl_idx, len(dpl.times),
                                              n_times))

    averages = {key: np.mean(np.array([dpl.data[key] for dpl in dpls]),
                             axis=0)
                for key in ('agg', 'L2', 'L5')}
    # column order must match Dipole's (agg, L2, L5) layout
    avg_dpl_data = np.c_[averages['agg'], averages['L2'], averages['L5']]

    # nave records how many trials were averaged into this dipole
    avg_dpl = Dipole(dpls[0].times, avg_dpl_data, nave=len(dpls))
    # keep the inputs' units instead of Dipole's default 'fAm' label
    avg_dpl.units = units
    return avg_dpl
class Dipole(object):
    """Dipole class.

    Parameters
    ----------
    times : array (n_times,)
        The time vector (ms)
    data : array (n_times x 3)
        The data. The first column represents 'agg',
        the second 'L2' and the last one 'L5'
    nave : int
        Number of trials that were averaged to produce this Dipole. Defaults
        to 1

    Attributes
    ----------
    times : array
        The time vector
    data : dict of array
        The dipole with keys 'agg', 'L2' and 'L5'. On construction each entry
        is a column view of the ``data`` array that was passed in.
    nave : int
        Number of trials that were averaged to produce this Dipole
    units : str
        Units of the dipole values; 'fAm' on construction
    N : int
        Number of rows of ``data`` at construction
    """

    def __init__(self, times, data, nave=1):  # noqa: D102
        self.units = 'fAm'
        self.N = data.shape[0]
        self.times = times
        self.data = {'agg': data[:, 0], 'L2': data[:, 1], 'L5': data[:, 2]}
        self.nave = nave

    def convert_fAm_to_nAm(self):
        """Convert the dipole data in place from fAm to nAm.

        Must be run after :meth:`baseline_renormalize`, which only acts on
        fAm data. If the units are not 'fAm' a warning is issued and the data
        are left untouched, so repeated calls cannot rescale twice.
        """
        if self.units != 'fAm':
            warnings.warn("No conversion to nAm done because units were"
                          " in %s" % self.units)
            return
        for key in self.data:
            self.data[key] *= 1e-6
        self.units = 'nAm'

    def scale(self, fctr):
        """Multiply every dipole component in place by ``fctr``.

        Parameters
        ----------
        fctr : float
            The scaling factor.

        Returns
        -------
        fctr : float
            The scaling factor that was applied (returned unchanged).
        """
        for key in self.data:
            self.data[key] *= fctr
        return fctr

    def smooth(self, winsz):
        """Smooth every dipole component in place with a Hamming window.

        Parameters
        ----------
        winsz : int
            Length of the Hamming window in samples. Values <= 1 leave the
            data unchanged.

        Raises
        ------
        ValueError
            If ``winsz`` is larger than the number of time samples.
        """
        if winsz <= 1:
            return
        n_times = len(self.times)
        if winsz > n_times:
            # np.convolve(mode='same') returns max(len(x), winsz) samples, so
            # a longer window would make the data longer than self.times
            raise ValueError("winsz (%s) must not be larger than the number"
                             " of time samples (%d)" % (winsz, n_times))
        for key in self.data:
            self.data[key] = _hammfilt(self.data[key], winsz)

    def plot(self, ax=None, layer='agg', show=True):
        """Simple layer-specific plot function.

        Parameters
        ----------
        ax : instance of matplotlib axis | None
            The matplotlib axis
        layer : str
            The layer to plot. Can be one of
            'agg', 'L2', and 'L5'
        show : bool
            If True, show the figure

        Returns
        -------
        fig : instance of plt.fig
            The matplotlib figure handle.
        """
        return plot_dipole(dpl=self, ax=ax, layer=layer, show=show)

    def baseline_renormalize(self, params):
        """Only baseline renormalize if the units are fAm.

        Subtracts an empirically fitted offset from the 'L2' (constant) and
        'L5' (piecewise in time) dipoles, scaled by the number of pyramidal
        cells per layer, then recomputes 'agg' as 'L2' + 'L5'. If the units
        are not 'fAm' a warning is issued and nothing is changed.

        Parameters
        ----------
        params : dict
            The parameters; 'N_pyr_x' and 'N_pyr_y' are read.
        """
        if self.units != 'fAm':
            warnings.warn("No dipole renormalization done because units"
                          " were in %s" % (self.units))
            return

        # N_pyr cells in grid. This is PER LAYER
        N_pyr = params['N_pyr_x'] * params['N_pyr_y']

        # dipole offset calculation: increasing number of pyr cells (L2 and
        # L5, simultaneously) with no inputs resulted in an aggregate dipole
        # over the interval [50., 1000.] ms that eventually plateaus at
        # -48 fAm. The range over this interval is something like 3 fAm, so
        # the resultant correction is here, per dipole. These values will be
        # subtracted.
        dpl_offset = {
            'L2': N_pyr * 0.0443,
            'L5': N_pyr * -49.0502,
        }

        # L2 dipole offset can be roughly baseline shifted over the entire
        # range of t
        self.data['L2'] -= dpl_offset['L2']

        # L5 dipole offset is piecewise in time. Slope (m) and intercept (b)
        # are uncorrected for N_cells and were fit over the range [37., 750.)
        m = 3.4770508e-3
        b = -51.231085
        # these values were fit over the range [750., 5000]
        t1 = 750.
        m1 = 1.01e-4
        b1 = -48.412078

        # masks depend only on times, so they can be computed up front
        early = self.times <= 37.
        middle = (self.times > 37.) & (self.times < t1)
        late = self.times >= t1
        self.data['L5'][early] -= dpl_offset['L5']
        self.data['L5'][middle] -= N_pyr * (m * self.times[middle] + b)
        self.data['L5'][late] -= N_pyr * (m1 * self.times[late] + b1)

        # recalculate the aggregate dipole based on the baseline
        # normalized ones
        self.data['agg'] = self.data['L2'] + self.data['L5']

    def write(self, fname):
        """Write dipole values to a file.

        Parameters
        ----------
        fname : str
            Full path to the output file (.txt)

        Outputs
        -------
        A tab separated txt file where rows correspond
        to samples and columns correspond to
        1) time (ms),
        2) aggregate current dipole,
        3) L2/3 current dipole, and
        4) L5 current dipole
        Dipole columns are written in the current ``units`` (after any
        scaling); the units themselves are not stored in the file.
        """
        if self.nave > 1:
            warnings.warn("Saving Dipole to file that is an average of %d"
                          " trials" % self.nave)
        X = np.r_[[self.times, self.data['agg'], self.data['L2'],
                   self.data['L5']]].T
        np.savetxt(fname, X, fmt=['%3.3f', '%5.4f', '%5.4f', '%5.4f'],
                   delimiter='\t')