58 changes: 31 additions & 27 deletions phylib/io/alf.py
@@ -27,11 +27,11 @@
# File utils
#------------------------------------------------------------------------------

NCH_WAVEFORMS = 32 # number of channels to be saved in templates.waveforms and channels.waveforms
NSAMPLE_WAVEFORMS = 500 # number of waveforms sampled out of the raw data

_FILE_RENAMES = [ # file_in, file_out, squeeze (bool: squeeze Matlab vectors when saving to npy)
('params.py', 'params.py', None),
('cluster_metrics.csv', 'clusters.metrics.csv', None),
('cluster_KSLabel.tsv', 'cluster_KSLabel.tsv', None),
('spike_clusters.npy', 'spikes.clusters.npy', True),
('spike_templates.npy', 'spikes.templates.npy', True),
('channel_positions.npy', 'channels.localCoordinates.npy', False),
@@ -42,6 +42,9 @@
('_phy_spikes_subset.channels.npy', '_phy_spikes_subset.channels.npy', False),
('_phy_spikes_subset.spikes.npy', '_phy_spikes_subset.spikes.npy', False),
('_phy_spikes_subset.waveforms.npy', '_phy_spikes_subset.waveforms.npy', False),
('drift.depth_scale.npy', 'drift.depth_scale.npy', False),
('drift.time_scale.npy', 'drift.time_scale.npy', False),
('drift.um.npy', 'drift.um.npy', False),
# ('cluster_group.tsv', 'ks2/clusters.phyAnnotation.tsv', False), # todo check indexing, add2QC
]

@@ -116,21 +119,23 @@ def convert(self, out_path, force=False, label='', ampfactor=1):
if not self.out_path.exists():
self.out_path.mkdir()

with tqdm(desc="Converting to ALF", total=95) as bar:
self.copy_files(force=force)
bar.update(10)
self.make_spike_times_amplitudes()
with tqdm(desc="Converting to ALF", total=125) as bar:
bar.update(10)
self.make_cluster_objects()
bar.update(10)
self.make_channel_objects()
bar.update(5)
self.make_template_and_spikes_objects()
bar.update(30)
self.model.save_spikes_subset_waveforms(
NSAMPLE_WAVEFORMS, sample2unit=self.ampfactor)
bar.update(50)
self.make_depths()
bar.update(20)
self.make_template_object()
bar.update(30)
self.rm_files()
bar.update(10)
self.copy_files(force=force)
bar.update(10)
self.rename_with_label()

# Return the TemplateModel of the converted ALF dataset if the params.py file exists.
@@ -165,16 +170,8 @@ def _save_npy(self, filename, arr):
"""Save an array into a .npy file."""
np.save(self.out_path / filename, arr)

def make_spike_times_amplitudes(self):
"""We cannot just rename/copy spike_times.npy because it is in unit of
*samples*, and not in seconds."""
self._save_npy('spikes.times.npy', self.model.spike_times)
self._save_npy('spikes.samples.npy', self.model.spike_samples)
self._save_npy('spikes.amps.npy', self.model.get_amplitudes_true() * self.ampfactor)

def make_cluster_objects(self):
"""Create clusters.channels, clusters.waveformsDuration and clusters.amps"""

peak_channel_path = self.dir_path / 'clusters.channels.npy'
if not peak_channel_path.exists():
self._save_npy(peak_channel_path.name, self.model.templates_channels)
@@ -184,8 +181,8 @@ def make_cluster_objects(self):
self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)

# group by average over cluster number
camps = np.zeros(np.max(self.cluster_ids) - np.min(self.cluster_ids) + 1,) * np.nan
camps[self.cluster_ids - np.min(self.cluster_ids)] = self.model.templates_amplitudes
camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
camps[self.cluster_ids] = self.model.templates_amplitudes
amps_path = self.dir_path / 'clusters.amps.npy'
self._save_npy(amps_path.name, camps * self.ampfactor)

@@ -198,7 +195,7 @@ def make_channel_objects(self):
def make_channel_objects(self):
"""If there is no rawInd file, create it"""
rawInd_path = self.dir_path / 'channels.rawInd.npy'
rawInd = np.zeros_like(self.model.channel_probes).astype(np.int)
rawInd = np.zeros_like(self.model.channel_probes).astype(int)
channel_offset = 0
for probe in np.unique(self.model.channel_probes):
ind = self.model.channel_probes == probe
@@ -225,20 +222,27 @@ def make_depths(self):
spikes_depths = clusters_depths[spike_clusters]
else:
spikes_depths = self.model.get_depths()
# if PC features are provided, compute the depth as the weighted sum of coordinates

self._save_npy('spikes.depths.npy', spikes_depths)
self._save_npy('clusters.depths.npy', clusters_depths)

def make_template_object(self):
def make_template_and_spikes_objects(self):
"""Creates the template waveforms sparse object
Without manual curation, it also corresponds to clusters waveforms objects.
"""
# "We cannot just rename/copy spike_times.npy because it is in unit of samples,
# and not seconds
self._save_npy('spikes.times.npy', self.model.spike_times)
self._save_npy('spikes.samples.npy', self.model.spike_samples)
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor)
self._save_npy('spikes.amps.npy', spike_amps)
self._save_npy('templates.amps.npy', template_amps)

if self.model.sparse_templates.cols:
raise(NotImplementedError("Sparse template export to ALF not implemented yet"))
else:
n_templates, n_wavsamps, nchall = self.model.sparse_templates.data.shape
ncw = min(NCH_WAVEFORMS, nchall) # for some datasets, 32 may be too much
n_templates, n_wavsamps, nchall = templates_v.shape
# for some datasets, 32 may be too much
ncw = min(self.model.n_closest_channels, nchall)
assert(n_templates == self.model.n_templates)
templates = np.zeros((n_templates, n_wavsamps, ncw), dtype=np.float32)
templates_inds = np.zeros((n_templates, ncw), dtype=np.int32)
@@ -250,10 +254,10 @@
self.model.channel_positions[self.model.templates_channels[t]]), axis=1)
channel_distance[self.model.channel_probes != current_probe] += np.inf
templates_inds[t, :] = np.argsort(channel_distance)[:ncw]
templates[t, ...] = self.model.sparse_templates.data[t, :][:, templates_inds[t, :]]
np.save(self.out_path.joinpath('templates.waveforms'), templates * self.ampfactor)
templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]]
np.save(self.out_path.joinpath('templates.waveforms'), templates)
np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds)
np.save(self.out_path.joinpath('clusters.waveforms'), templates * self.ampfactor)
np.save(self.out_path.joinpath('clusters.waveforms'), templates)
np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds)

def rename_with_label(self):
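
For reviewers trying the new pipeline end to end, here is a minimal usage sketch. The class and method names (EphysAlfCreator, TemplateModel, convert) come from this diff and its tests; the paths and the ampfactor value are illustrative assumptions only.

from pathlib import Path
from phylib.io.alf import EphysAlfCreator
from phylib.io.model import TemplateModel

# Hypothetical Kilosort/phy output directory with spike_times.npy, params.py, etc.
ks_dir = Path('/data/kilosort_output')

model = TemplateModel(dir_path=ks_dir)  # the tests also pass dat_path, sample_rate, ...
creator = EphysAlfCreator(model)
# ampfactor converts raw samples to a physical unit (e.g. volts). After this PR
# it is threaded through spikes.amps, templates.waveforms and the subset
# waveform extraction instead of being applied per file.
alf_model = creator.convert(ks_dir / 'alf', ampfactor=2.34375e-6, force=True)
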
4 changes: 2 additions & 2 deletions phylib/io/array.py
@@ -115,7 +115,7 @@ def _index_of(arr, lookup):
# values
lookup = np.asarray(lookup, dtype=np.int32)
m = (lookup.max() if len(lookup) else 0) + 1
tmp = np.zeros(m + 1, dtype=np.int)
tmp = np.zeros(m + 1, dtype=int)
# Ensure that -1 values are kept.
tmp[-1] = -1
if len(lookup):
@@ -327,7 +327,7 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None):
def _spikes_in_clusters(spike_clusters, clusters):
"""Return the ids of all spikes belonging to the specified clusters."""
if len(spike_clusters) == 0 or len(clusters) == 0:
return np.array([], dtype=np.int)
return np.array([], dtype=int)
return np.nonzero(np.in1d(spike_clusters, clusters))[0]


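
Context for the dtype edits scattered through this PR (array.py here, plus ccg.py and _types.py below): np.int, np.float and np.bool have been deprecated aliases of the Python builtins since NumPy 1.20, and are removed in NumPy 1.24, so the diff swaps them for the builtins. A quick sketch of the equivalence; explicit sized dtypes such as np.int32 are unaffected:

import numpy as np

a = np.zeros(3, dtype=int)    # was: dtype=np.int
b = np.zeros(3, dtype=float)  # was: dtype=np.float
m = np.ones(3, dtype=bool)    # was: dtype=np.bool
assert a.dtype == np.int_ and b.dtype == np.float64 and m.dtype == np.bool_
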
89 changes: 59 additions & 30 deletions phylib/io/model.py
@@ -1035,44 +1035,70 @@ def get_template_features(self, spike_ids):
def get_depths(self):
"""Compute spike depths based on spike pc features and probe depths."""
# compute the depth as the weighted sum of coordinates
batch_sz = 50000 # number of spikes per batch
# if PC features are provided, compute the depth as the weighted sum of coordinates
nbatch = 50000
c = 0
spike_depths = np.zeros_like(self.spike_times)
nspi = spike_depths.shape[0]
spikes_depths = np.zeros_like(self.spike_times) * np.nan
nspi = spikes_depths.shape[0]
if self.sparse_features is None or self.sparse_features.data.shape[0] != self.n_spikes:
return None
while True:
ispi = np.arange(c, min(c + batch_sz, nspi))
ispi = np.arange(c, min(c + nbatch, nspi))
# take only first component
features = np.square(self.sparse_features.data[ispi, :, 0])
ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.int64)
features = self.sparse_features.data[ispi, :, 0]
features = np.maximum(features, 0) ** 2 # only positive values are taken into account
ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.uint32)
ypos = self.channel_positions[ichannels, 1]

spike_depths[ispi] = np.sum(np.transpose(ypos * features) /
np.sum(features, axis=1), axis=0)
c += batch_sz
with np.errstate(divide='ignore'):
spikes_depths[ispi] = (np.sum(np.transpose(ypos * features) /
np.sum(features, axis=1), axis=0))
c += nbatch
if c >= nspi:
break
return spikes_depths

return spike_depths

def get_amplitudes_true(self):
def get_amplitudes_true(self, sample2unit=1.):
"""Convert spike amplitude values to input amplitudes units
via scaling by unwhitened template waveform."""
# unwhiten template waveforms on their channels of max amplitude
templates_chs = self.templates_channels
templates_wfs = self.sparse_templates.data[np.arange(self.n_templates), :, templates_chs]
templates_wfs_unw = templates_wfs.T * self.wmi[templates_chs, templates_chs]
templates_amps = np.abs(
np.max(templates_wfs_unw, axis=0) - np.min(templates_wfs_unw, axis=0))
via scaling by the unwhitened template waveform.
:param sample2unit float: factor to convert the raw data to a physical unit (defaults to 1.)
:returns: spike_amplitudes_volts: np.array [nspikes]: spike amplitudes in raw data units
:returns: templates_volts: np.array [ntemplates, nsamples, nchannels]: templates
in raw data units
:returns: template_amps_volts: np.array [ntemplates]: average template amplitudes
in raw data units
To scale the template for template matching,
raw_data_volts = templates_volts * spike_amplitudes_volts / template_amps_volts
"""
# spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
# to rescale the template,

# scale the spike amplitude values by the template amplitude values
amplitudes_v = np.zeros_like(self.amplitudes)
for t in range(self.n_templates):
idxs = self.get_template_spikes(t)
amplitudes_v[idxs] = self.amplitudes[idxs] * templates_amps[t]

return amplitudes_v
# unwhiten template waveforms on their channels of max amplitude
if self.sparse_templates.cols:
raise NotImplementedError
# apply the inverse whitening matrix to the template
templates_wfs = np.zeros_like(self.sparse_templates.data) # nt, ns, nc
for n in np.arange(self.n_templates):
templates_wfs[n, :, :] = np.matmul(self.sparse_templates.data[n, :, :], self.wmi)

# The amplitude on each channel is the positive peak minus the negative
templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)

# The template arbitrary unit amplitude is the amplitude of its largest channel
# (but see below for true tempAmps)
templates_amps_au = np.max(templates_ch_amps, axis=1)
spike_amps = templates_amps_au[self.spike_templates] * self.amplitudes

with np.errstate(divide='ignore'):
# take the average spike amplitude per template
templates_amps_v = (np.bincount(self.spike_templates, weights=spike_amps) /
np.bincount(self.spike_templates))
# scale back the template according to the spikes units
templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
)[:, np.newaxis, np.newaxis]

return (spike_amps * sample2unit,
templates_physical_unit * sample2unit,
templates_amps_v * sample2unit)

#--------------------------------------------------------------------------
# Internal helper methods for public high-level methods
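
An aside on the get_depths computation above: the depth estimate is a center of mass of channel y-positions, weighted by the rectified and squared first PC feature. A minimal sketch, with synthetic stand-ins for the model attributes:

import numpy as np

features = np.array([[0.9, 0.4, -0.2]])          # one spike, 3 local channels, PC1 only
ypos = np.array([20., 40., 60.])                 # channel y-positions (depth, um)
w = np.maximum(features, 0) ** 2                 # negative loadings are discarded
depths = (ypos * w).sum(axis=1) / w.sum(axis=1)  # weighted average per spike, ~23.3 um
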
@@ -1232,15 +1258,18 @@ def save_spike_clusters(self, spike_clusters):
logger.debug("Save spike clusters to `%s`.", path)
np.save(path, spike_clusters)

def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_channels=None):
def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_channels=None,
sample2unit=1.):
if self.traces is None:
logger.warning(
"Spike waveforms could not be extracted as the raw data file is not available.")
return

n_chunks_kept = 20 # TODO: better choice
nst = max_n_spikes_per_template
nc = max_n_channels
nc = max_n_channels or self.n_closest_channels
nc = max(nc, self.n_closest_channels)

assert nst > 0
assert nc > 0

@@ -1275,7 +1304,7 @@ def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_cha
# Extract waveforms from the raw data on a chunk by chunk basis.
export_waveforms(
path, self.traces, self.spike_samples[spike_ids], spike_channels,
n_samples_waveforms=self.n_samples_waveforms)
n_samples_waveforms=self.n_samples_waveforms, sample2unit=sample2unit)

# Reload spike waveforms.
self.spike_waveforms = self._load_spike_waveforms()
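
To make the new get_amplitudes_true flow concrete, here is a standalone sketch with tiny synthetic arrays standing in for the model attributes. The shapes and the bincount averaging mirror the diff; all values are made up, and the real method additionally guards the division with np.errstate for templates that have no spikes.

import numpy as np

n_templates, n_samples, n_channels = 2, 4, 3
rng = np.random.default_rng(0)
sparse_templates = rng.normal(size=(n_templates, n_samples, n_channels))  # whitened
wmi = np.eye(n_channels)                               # inverse whitening matrix
spike_templates = np.array([0, 0, 1, 1, 0, 1])         # template id per spike
amplitudes = np.array([0.8, 1.2, 0.9, 1.1, 1.0, 1.3])  # KS2 arbitrary-unit amps

# Unwhiten each template, then take the peak-to-peak amplitude per channel.
templates_wfs = sparse_templates @ wmi                 # (nt, ns, nc)
ch_amps = templates_wfs.max(axis=1) - templates_wfs.min(axis=1)
amps_au = ch_amps.max(axis=1)                          # arbitrary-unit amp per template

# Spike amplitudes in raw-data units, then per-template averages via bincount.
spike_amps = amps_au[spike_templates] * amplitudes
template_amps = (np.bincount(spike_templates, weights=spike_amps) /
                 np.bincount(spike_templates))

# Rescale so that templates_v * spike_amp / template_amp approximates the raw data.
templates_v = templates_wfs * (template_amps / amps_au)[:, None, None]
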
14 changes: 10 additions & 4 deletions phylib/io/tests/test_alf.py
@@ -36,7 +36,7 @@ def __init__(self, tempdir):
self.nt = 5
self.ncd = 1000
np.save(p / 'spike_times.npy', .01 * np.cumsum(nr.exponential(size=self.ns)))
np.save(p / 'spike_clusters.npy', nr.randint(low=0, high=self.nt, size=self.ns))
np.save(p / 'spike_clusters.npy', nr.randint(low=1, high=self.nt, size=self.ns))
shutil.copy(p / 'spike_clusters.npy', p / 'spike_templates.npy')
np.save(p / 'amplitudes.npy', nr.uniform(low=0.5, high=1.5, size=self.ns))
np.save(p / 'channel_positions.npy', np.c_[np.arange(self.nc), np.zeros(self.nc)])
@@ -174,16 +174,22 @@ def check_conversion_output():
assert f.exists()

# make sure the output dimensions match (especially clusters, which should be 4)
cl_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('clusters.') and
f.name.endswith('.npy')]
cl_shape = []
for f in new_files:
if f.name.startswith('clusters.') and f.name.endswith('.npy'):
cl_shape.append(np.load(f).shape[0])
elif f.name.startswith('clusters.') and f.name.endswith('.csv'):
with open(f) as fid:
cl_shape.append(len(fid.readlines()) - 1)
sp_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('spikes.')]
ch_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('channels.')]

assert len(set(cl_shape)) == 1
assert len(set(sp_shape)) == 1
assert len(set(ch_shape)) == 1

dur = np.load(next(out_path.glob('clusters.peakToTrough*.npy')))
assert np.all(dur == np.array([18., -1., 9.5, 2.5, -2.]))
assert np.all(dur == np.array([-9.5, 3., 13., -4.5, -2.5]))

def read_after_write():
model = TemplateModel(dir_path=out_path, dat_path=dataset.dat_path,
Expand Down
9 changes: 5 additions & 4 deletions phylib/io/traces.py
@@ -641,20 +641,21 @@ def iter_waveforms(traces, spike_samples, spike_channels, n_samples_waveforms=No


def export_waveforms(
path, traces, spike_samples, spike_channels, n_samples_waveforms=None, cache=False):
path, traces, spike_samples, spike_channels, n_samples_waveforms=None, cache=False,
sample2unit=1):
"""Export a selection of spike waveforms to a npy file by iterating over the data on a chunk
by chunk basis."""
n_spikes = len(spike_samples)
spike_channels = np.asarray(spike_channels, dtype=np.int32)
n_channels_loc = spike_channels.shape[1]
shape = (n_spikes, n_samples_waveforms, n_channels_loc)

writer = NpyWriter(path, shape, traces.dtype)
dtype = traces.dtype if sample2unit is None else float
writer = NpyWriter(path, shape, dtype)
size_written = 0
for waveforms in iter_waveforms(
traces, spike_samples, spike_channels, n_samples_waveforms=n_samples_waveforms,
cache=cache):
writer.append(waveforms)
writer.append(waveforms * sample2unit)
size_written += waveforms.size
writer.close()
assert prod(shape) == size_written
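
A note on the dtype switch in export_waveforms above: multiplying int16 raw samples by a float conversion factor promotes the product to float64, so the NpyWriter must be opened with a float dtype whenever sample2unit is applied, otherwise appending would silently truncate the scaled values. A tiny illustration (the scale factor is a made-up volts-per-bit value):

import numpy as np

raw_chunk = np.array([[-512, 1024]], dtype=np.int16)  # stand-in for a traces chunk
sample2unit = 2.34375e-06                             # hypothetical volts per bit
scaled = raw_chunk * sample2unit
assert scaled.dtype == np.float64                     # promoted by the multiplication
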
2 changes: 1 addition & 1 deletion phylib/stats/ccg.py
@@ -148,7 +148,7 @@ def correlograms(

# At a given shift, the mask specifies which spikes have matching spikes
# within the correlogram time window.
mask = np.ones_like(spike_samples, dtype=np.bool)
mask = np.ones_like(spike_samples, dtype=bool)

correlograms = _create_correlograms_array(n_clusters, winsize_bins)

4 changes: 2 additions & 2 deletions phylib/utils/_types.py
@@ -15,8 +15,8 @@
#------------------------------------------------------------------------------

_ACCEPTED_ARRAY_DTYPES = (
np.float, np.float32, np.float64, np.int, np.int8, np.int16, np.uint8, np.uint16,
np.int32, np.int64, np.uint32, np.uint64, np.bool)
float, np.float32, np.float64, int, np.int8, np.int16, np.uint8, np.uint16,
np.int32, np.int64, np.uint32, np.uint64, bool)


class Bunch(dict):
10 changes: 5 additions & 5 deletions phylib/utils/tests/test_types.py
@@ -88,13 +88,13 @@ def _check(arr):
_check(_as_array(3.))
_check(_as_array([3]))

_check(_as_array(3, np.float))
_check(_as_array(3., np.float))
_check(_as_array([3], np.float))
_check(_as_array(3, float))
_check(_as_array(3., float))
_check(_as_array([3], float))
_check(_as_array(np.array([3])))
with raises(ValueError):
_check(_as_array(np.array([3]), dtype=np.object))
_check(_as_array(np.array([3]), np.float))
_check(_as_array(np.array([3]), dtype=object))
_check(_as_array(np.array([3]), float))

assert _as_array(None) is None
assert not _is_array_like(None)