Skip to content

Commit

Permalink
Started tests on real data.
Browse files Browse the repository at this point in the history
  • Loading branch information
rossant committed Nov 27, 2013
1 parent e6376da commit ef5da69
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 29 deletions.
4 changes: 2 additions & 2 deletions dev/file_format/experiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"prm = {\n",
" 'nchannels': nchannels,\n",
" 'nfeatures': nchannels*3,\n",
" 'nwavesamples': 20\n",
" 'waveforms_nsamples': 20\n",
" }\n",
"prb = {'channel_groups': [\n",
" {\n",
Expand Down Expand Up @@ -316,7 +316,7 @@
"cell_type": "code",
"collapsed": false,
"input": [
"exp.application_data.spikedetekt.nwavesamples"
"exp.application_data.spikedetekt.waveforms_nsamples"
],
"language": "python",
"metadata": {},
Expand Down
3 changes: 1 addition & 2 deletions spikedetekt2/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ def run(raw_data=None, experiment=None, prm=None, probe=None):
for chunk in raw_data.chunks(chunk_size=chunk_size,
chunk_overlap=chunk_overlap,):
# Filter the (full) chunk.
# shape: (nsamples, nchannels)
chunk_raw = chunk.data_chunk_full
chunk_raw = chunk.data_chunk_full # shape: (nsamples, nchannels)
chunk_fil = apply_filter(chunk_raw, filter=filter)

# Apply thresholds.
Expand Down
4 changes: 0 additions & 4 deletions spikedetekt2/core/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,13 @@

sample_rate = 20000
duration = 1.
nwavesamples = 10
nchannels = 2
chunk_size = 20000
nsamples = int(sample_rate * duration)
raw_data = .1 * np.random.randn(nsamples, nchannels)

prm = get_params(**{
'nwavesamples': nwavesamples,
'nchannels': nchannels,
'sample_rate': sample_rate,
'chunk_size': chunk_size,
'detect_spikes': 'positive',
})
prb = {'channel_groups': [
Expand Down
4 changes: 3 additions & 1 deletion spikedetekt2/dataio/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def _get_child_id(child):
def _print_instance(obj, depth=0, name=''):
# Handle the first element of the list/dict.
if isinstance(obj, (list, dict)):
if not obj:
r = []
return r
if isinstance(obj, list):
sobj = obj[0]
key = '0'
Expand Down Expand Up @@ -293,7 +296,6 @@ def add(self, time_samples=None, time_fractional=0,
self.cluster.append((cluster,))
self.cluster_original.append((cluster_original,))
self.features_masks.append(features_masks)

self.waveforms_raw.append(waveforms_raw)
self.waveforms_filtered.append(waveforms_filtered)

Expand Down
12 changes: 6 additions & 6 deletions spikedetekt2/dataio/kwik.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def create_kwx(path, prb=None, prm=None, has_masks=True):
Arguments:
* prb: the PRB dictionary
* nwavesamples (common to all channel groups if set)
* waveforms_nsamples (common to all channel groups if set)
* nfeatures (total number of features per spike, common to all channel groups if set)
* nchannels (number of channels per channel group, common to all channel groups if set)
Expand All @@ -185,20 +185,20 @@ def create_kwx(path, prb=None, prm=None, has_masks=True):
nchannels = prm.get('nchannels', None)
nfeatures_per_channel = prm.get('nfeatures_per_channel', None)
nfeatures = prm.get('nfeatures', None)
nwavesamples = prm.get('nwavesamples', None)
waveforms_nsamples = prm.get('waveforms_nsamples', None)

file = tb.openFile(path, mode='w')
file.createGroup('/', 'channel_groups')

for ichannel_group, chgrp_info in enumerate(prb.get('channel_groups', [])):
nchannels_ = len(chgrp_info.get('channels', [])) or nchannels or 0
nwavesamples_ = chgrp_info.get('nwavesamples', nwavesamples) or 0
waveforms_nsamples_ = chgrp_info.get('waveforms_nsamples', waveforms_nsamples) or 0
nfeatures_per_channel_ = chgrp_info.get('nfeatures_per_channel', nfeatures_per_channel) or 0
nfeatures_ = chgrp_info.get('nfeatures', nfeatures) or nfeatures_per_channel_ * nchannels_

assert nchannels_ > 0
assert nfeatures_ > 0
assert nwavesamples_ > 0
assert waveforms_nsamples_ > 0

channel_group_path = '/channel_groups/{0:d}'.format(ichannel_group)

Expand All @@ -216,9 +216,9 @@ def create_kwx(path, prb=None, prm=None, has_masks=True):
tb.Float32Atom(), (0, nfeatures_))

file.createEArray(channel_group_path, 'waveforms_raw',
tb.Int16Atom(), (0, nwavesamples_, nchannels_))
tb.Int16Atom(), (0, waveforms_nsamples_, nchannels_))
file.createEArray(channel_group_path, 'waveforms_filtered',
tb.Int16Atom(), (0, nwavesamples_, nchannels_))
tb.Int16Atom(), (0, waveforms_nsamples_, nchannels_))

file.close()

Expand Down
4 changes: 2 additions & 2 deletions spikedetekt2/dataio/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def setup():
# Create files.
prm = {'nfeatures': 3, 'nwavesamples': 10, 'nchannels': 3}
prm = {'nfeatures': 3, 'waveforms_nsamples': 10, 'nchannels': 3}
prb = {'channel_groups': [
{
'channels': [4, 6, 8],
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_experiment_channels():
assert exp.application_data
assert exp.user_data
assert exp.application_data.spikedetekt.nchannels == 3
assert exp.application_data.spikedetekt.nwavesamples == 10
assert exp.application_data.spikedetekt.waveforms_nsamples == 10
assert exp.application_data.spikedetekt.nfeatures == 3

# Channel group.
Expand Down
16 changes: 8 additions & 8 deletions spikedetekt2/dataio/tests/test_kwik.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
DIRPATH = tempfile.mkdtemp()

def setup_create():
prm = {'nfeatures': 3, 'nwavesamples': 10}
prm = {'nfeatures': 3, 'waveforms_nsamples': 10}
prb = {'channel_groups': [
{
'channels': [4, 6, 8],
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_create_kwik():
path = os.path.join(DIRPATH, 'myexperiment.kwik')

prm = {
'nwavesamples': 20,
'waveforms_nsamples': 20,
'nfeatures': 3*32,
}
prb = {'channel_groups': [
Expand All @@ -99,12 +99,12 @@ def test_create_kwx():
path = os.path.join(DIRPATH, 'myexperiment.kwx')

# Create the KWX file.
nwavesamples = 20
waveforms_nsamples = 20
nchannels = 32
nchannels2 = 24
nfeatures = 3*nchannels
prm = {
'nwavesamples': 20,
'waveforms_nsamples': 20,
'nfeatures': 3*nchannels,
}
prb = {'channel_groups': [
Expand All @@ -131,16 +131,16 @@ def test_create_kwx():
wr1 = f.root.channel_groups.__getattr__('1').waveforms_raw
wf1 = f.root.channel_groups.__getattr__('1').waveforms_filtered
assert fm1.shape[1:] == (3*nchannels2, 2)
assert wr1.shape[1:] == (nwavesamples, nchannels2)
assert wf1.shape[1:] == (nwavesamples, nchannels2)
assert wr1.shape[1:] == (waveforms_nsamples, nchannels2)
assert wf1.shape[1:] == (waveforms_nsamples, nchannels2)

# Group 2
fm2 = f.root.channel_groups.__getattr__('2').features_masks
wr2 = f.root.channel_groups.__getattr__('2').waveforms_raw
wf2 = f.root.channel_groups.__getattr__('2').waveforms_filtered
assert fm2.shape[1:] == (2*nchannels, 2)
assert wr2.shape[1:] == (nwavesamples, nchannels)
assert wf2.shape[1:] == (nwavesamples, nchannels)
assert wr2.shape[1:] == (waveforms_nsamples, nchannels)
assert wf2.shape[1:] == (waveforms_nsamples, nchannels)

f.close()

Expand Down
9 changes: 5 additions & 4 deletions spikedetekt2/utils/params_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

# Chunks
# ------
chunk_size = int(sample_rate)
chunk_overlap = int(.01 * sample_rate)
chunk_size = int(1. * sample_rate) # 1 second
chunk_overlap = int(.01 * sample_rate) # 10 ms

# Spike detection
# ---------------
Expand All @@ -27,8 +27,9 @@

# Spike extraction
# ----------------
extract_s_before = 5
extract_s_after = 5
extract_s_before = 10
extract_s_after = 10
waveforms_nsamples = extract_s_before + extract_s_after

# Features
# --------
Expand Down

0 comments on commit ef5da69

Please sign in to comment.