diff --git a/phylib/io/alf.py b/phylib/io/alf.py
index 73ebcea..ce2b9a4 100644
--- a/phylib/io/alf.py
+++ b/phylib/io/alf.py
@@ -27,11 +27,11 @@
 # File utils
 #------------------------------------------------------------------------------
 
-NCH_WAVEFORMS = 32  # number of channels to be saved in templates.waveforms and channels.waveforms
+NSAMPLE_WAVEFORMS = 500  # number of waveforms sampled out of the raw data
 
 _FILE_RENAMES = [  # file_in, file_out, squeeze (bool to squeeze vector from matlab in npy)
     ('params.py', 'params.py', None),
-    ('cluster_metrics.csv', 'clusters.metrics.csv', None),
+    ('cluster_KSLabel.tsv', 'cluster_KSLabel.tsv', None),
     ('spike_clusters.npy', 'spikes.clusters.npy', True),
     ('spike_templates.npy', 'spikes.templates.npy', True),
     ('channel_positions.npy', 'channels.localCoordinates.npy', False),
@@ -42,6 +42,9 @@
     ('_phy_spikes_subset.channels.npy', '_phy_spikes_subset.channels.npy', False),
     ('_phy_spikes_subset.spikes.npy', '_phy_spikes_subset.spikes.npy', False),
     ('_phy_spikes_subset.waveforms.npy', '_phy_spikes_subset.waveforms.npy', False),
+    ('drift.depth_scale.npy', 'drift.depth_scale.npy', False),
+    ('drift.time_scale.npy', 'drift.time_scale.npy', False),
+    ('drift.um.npy', 'drift.um.npy', False),
     # ('cluster_group.tsv', 'ks2/clusters.phyAnnotation.tsv', False),  # todo check indexing, add2QC
 ]
 
@@ -116,21 +119,23 @@ def convert(self, out_path, force=False, label='', ampfactor=1):
         if not self.out_path.exists():
             self.out_path.mkdir()
 
-        with tqdm(desc="Converting to ALF", total=95) as bar:
-            self.copy_files(force=force)
-            bar.update(10)
-            self.make_spike_times_amplitudes()
+        with tqdm(desc="Converting to ALF", total=125) as bar:
             bar.update(10)
             self.make_cluster_objects()
             bar.update(10)
             self.make_channel_objects()
             bar.update(5)
+            self.make_template_and_spikes_objects()
+            bar.update(30)
+            self.model.save_spikes_subset_waveforms(
+                NSAMPLE_WAVEFORMS, sample2unit=self.ampfactor)
+            bar.update(50)
             self.make_depths()
             bar.update(20)
-            self.make_template_object()
-            bar.update(30)
             self.rm_files()
             bar.update(10)
+            self.copy_files(force=force)
+            bar.update(10)
         self.rename_with_label()
 
         # Return the TemplateModel of the converted ALF dataset if the params.py file exists.
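
For context, a minimal sketch of how the reordered pipeline above is driven end to end. The paths, sample rate, channel count and scaling factor are hypothetical; `ampfactor` is the recording system's sample-to-volts conversion, now applied once in `convert()` and propagated to amps, templates and extracted waveforms:

    from pathlib import Path

    from phylib.io.alf import EphysAlfCreator
    from phylib.io.model import TemplateModel

    # Hypothetical Kilosort output directory and raw-data parameters.
    model = TemplateModel(dir_path=Path('ks2_output'), dat_path=Path('raw.bin'),
                          sample_rate=30000, n_channels_dat=385)
    creator = EphysAlfCreator(model)
    # ampfactor converts raw samples to volts; see make_template_and_spikes_objects
    # and save_spikes_subset_waveforms below for where it is consumed.
    creator.convert(Path('alf_output'), ampfactor=2.34e-6)
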
@@ -165,16 +170,8 @@ def _save_npy(self, filename, arr):
         """Save an array into a .npy file."""
         np.save(self.out_path / filename, arr)
 
-    def make_spike_times_amplitudes(self):
-        """We cannot just rename/copy spike_times.npy because it is in unit of
-        *samples*, and not in seconds."""
-        self._save_npy('spikes.times.npy', self.model.spike_times)
-        self._save_npy('spikes.samples.npy', self.model.spike_samples)
-        self._save_npy('spikes.amps.npy', self.model.get_amplitudes_true() * self.ampfactor)
-
     def make_cluster_objects(self):
         """Create clusters.channels, clusters.waveformsDuration and clusters.amps"""
-
         peak_channel_path = self.dir_path / 'clusters.channels.npy'
         if not peak_channel_path.exists():
             self._save_npy(peak_channel_path.name, self.model.templates_channels)
@@ -184,8 +181,8 @@ def make_cluster_objects(self):
             self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
 
         # group by average over cluster number
-        camps = np.zeros(np.max(self.cluster_ids) - np.min(self.cluster_ids) + 1,) * np.nan
-        camps[self.cluster_ids - np.min(self.cluster_ids)] = self.model.templates_amplitudes
+        camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
+        camps[self.cluster_ids] = self.model.templates_amplitudes
         amps_path = self.dir_path / 'clusters.amps.npy'
         self._save_npy(amps_path.name, camps * self.ampfactor)
@@ -198,7 +195,7 @@ def make_channel_objects(self):
         """If there is no rawInd file, create it"""
         rawInd_path = self.dir_path / 'channels.rawInd.npy'
-        rawInd = np.zeros_like(self.model.channel_probes).astype(np.int)
+        rawInd = np.zeros_like(self.model.channel_probes).astype(int)
         channel_offset = 0
         for probe in np.unique(self.model.channel_probes):
             ind = self.model.channel_probes == probe
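
The `clusters.amps` change in `make_cluster_objects` above sizes the output by the template count and indexes it with the cluster ids directly, so clusters dropped during curation stay NaN instead of being shifted by `min(cluster_ids)`. A toy illustration with made-up numbers:

    import numpy as np

    # 5 templates, but only clusters 1, 3 and 4 survived curation.
    n_templates = 5
    cluster_ids = np.array([1, 3, 4])
    templates_amplitudes = np.array([0.8, 1.1, 0.9])

    camps = np.zeros(n_templates) * np.nan
    camps[cluster_ids] = templates_amplitudes
    print(camps)  # [nan 0.8 nan 1.1 0.9]
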
""" + # "We cannot just rename/copy spike_times.npy because it is in unit of samples, + # and not seconds + self._save_npy('spikes.times.npy', self.model.spike_times) + self._save_npy('spikes.samples.npy', self.model.spike_samples) + spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor) + self._save_npy('spikes.amps.npy', spike_amps) + self._save_npy('templates.amps.npy', template_amps) + if self.model.sparse_templates.cols: raise(NotImplementedError("Sparse template export to ALF not implemented yet")) else: - n_templates, n_wavsamps, nchall = self.model.sparse_templates.data.shape - ncw = min(NCH_WAVEFORMS, nchall) # for some datasets, 32 may be too much + n_templates, n_wavsamps, nchall = templates_v.shape + # for some datasets, 32 may be too much + ncw = min(self.model.n_closest_channels, nchall) assert(n_templates == self.model.n_templates) templates = np.zeros((n_templates, n_wavsamps, ncw), dtype=np.float32) templates_inds = np.zeros((n_templates, ncw), dtype=np.int32) @@ -250,10 +254,10 @@ def make_template_object(self): self.model.channel_positions[self.model.templates_channels[t]]), axis=1) channel_distance[self.model.channel_probes != current_probe] += np.inf templates_inds[t, :] = np.argsort(channel_distance)[:ncw] - templates[t, ...] = self.model.sparse_templates.data[t, :][:, templates_inds[t, :]] - np.save(self.out_path.joinpath('templates.waveforms'), templates * self.ampfactor) + templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]] + np.save(self.out_path.joinpath('templates.waveforms'), templates) np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds) - np.save(self.out_path.joinpath('clusters.waveforms'), templates * self.ampfactor) + np.save(self.out_path.joinpath('clusters.waveforms'), templates) np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds) def rename_with_label(self): diff --git a/phylib/io/array.py b/phylib/io/array.py index 29a8d3d..4ad9829 100644 --- a/phylib/io/array.py +++ b/phylib/io/array.py @@ -115,7 +115,7 @@ def _index_of(arr, lookup): # values lookup = np.asarray(lookup, dtype=np.int32) m = (lookup.max() if len(lookup) else 0) + 1 - tmp = np.zeros(m + 1, dtype=np.int) + tmp = np.zeros(m + 1, dtype=int) # Ensure that -1 values are kept. 
diff --git a/phylib/io/array.py b/phylib/io/array.py
index 29a8d3d..4ad9829 100644
--- a/phylib/io/array.py
+++ b/phylib/io/array.py
@@ -115,7 +115,7 @@ def _index_of(arr, lookup):
     # values
     lookup = np.asarray(lookup, dtype=np.int32)
     m = (lookup.max() if len(lookup) else 0) + 1
-    tmp = np.zeros(m + 1, dtype=np.int)
+    tmp = np.zeros(m + 1, dtype=int)
     # Ensure that -1 values are kept.
     tmp[-1] = -1
     if len(lookup):
@@ -327,7 +327,7 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None):
 
 def _spikes_in_clusters(spike_clusters, clusters):
     """Return the ids of all spikes belonging to the specified clusters."""
     if len(spike_clusters) == 0 or len(clusters) == 0:
-        return np.array([], dtype=np.int)
+        return np.array([], dtype=int)
     return np.nonzero(np.in1d(spike_clusters, clusters))[0]
 
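
For readers unfamiliar with `_index_of`, the `tmp` array touched above is a dense lookup table mapping a value to its position in `lookup`; the extra slot at the end lets `-1` entries pass through unchanged. A toy run:

    import numpy as np

    lookup = np.array([3, 7, 9], dtype=np.int32)
    arr = np.array([9, 3, 3, 7])

    m = lookup.max() + 1
    tmp = np.zeros(m + 1, dtype=int)
    tmp[-1] = -1            # so that a query of -1 maps back to -1
    tmp[lookup] = np.arange(len(lookup))
    print(tmp[arr])         # [2 0 0 1]
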
diff --git a/phylib/io/model.py b/phylib/io/model.py
index 3fdfdcf..4d4c116 100644
--- a/phylib/io/model.py
+++ b/phylib/io/model.py
@@ -1035,44 +1035,70 @@ def get_template_features(self, spike_ids):
     def get_depths(self):
         """Compute spike depths based on spike pc features and probe depths."""
         # compute the depth as the weighted sum of coordinates
-        batch_sz = 50000  # number of spikes per batch
+        # if PC features are provided, compute the depth as the weighted sum of coordinates
+        nbatch = 50000
         c = 0
-        spike_depths = np.zeros_like(self.spike_times)
-        nspi = spike_depths.shape[0]
+        spikes_depths = np.zeros_like(self.spike_times) * np.nan
+        nspi = spikes_depths.shape[0]
         if self.sparse_features is None or self.sparse_features.data.shape[0] != self.n_spikes:
             return None
         while True:
-            ispi = np.arange(c, min(c + batch_sz, nspi))
+            ispi = np.arange(c, min(c + nbatch, nspi))
             # take only first component
-            features = np.square(self.sparse_features.data[ispi, :, 0])
-            ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.int64)
+            features = self.sparse_features.data[ispi, :, 0]
+            features = np.maximum(features, 0) ** 2  # takes only positive values into account
+            ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.uint32)
             ypos = self.channel_positions[ichannels, 1]
-
-            spike_depths[ispi] = np.sum(np.transpose(ypos * features) /
-                                        np.sum(features, axis=1), axis=0)
-            c += batch_sz
+            with np.errstate(divide='ignore'):
+                spikes_depths[ispi] = (np.sum(np.transpose(ypos * features) /
+                                              np.sum(features, axis=1), axis=0))
+            c += nbatch
             if c >= nspi:
                 break
+        return spikes_depths
 
-        return spike_depths
-
-    def get_amplitudes_true(self):
+    def get_amplitudes_true(self, sample2unit=1.):
         """Convert spike amplitude values to input amplitudes units
-        via scaling by unwhitened template waveform."""
-        # unwhiten template waveforms on their channels of max amplitude
-        templates_chs = self.templates_channels
-        templates_wfs = self.sparse_templates.data[np.arange(self.n_templates), :, templates_chs]
-        templates_wfs_unw = templates_wfs.T * self.wmi[templates_chs, templates_chs]
-        templates_amps = np.abs(
-            np.max(templates_wfs_unw, axis=0) - np.min(templates_wfs_unw, axis=0))
+        via scaling by unwhitened template waveform.
+        :param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.)
+        :returns: spike_amplitudes_volts: np.array [nspikes]: spike amplitudes in raw data units
+        :returns: templates_volts: np.array [ntemplates, nsamples, nchannels]: templates
+         in raw data units
+        :returns: template_amps_volts: np.array [ntemplates]: average template amplitudes
+         in raw data units
+        To scale the template for template matching,
+        raw_data_volts = templates_volts * spike_amplitudes_volts / template_amps_volts
+        """
+        # spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
+        # to rescale the template,
-        # scale the spike amplitude values by the template amplitude values
-        amplitudes_v = np.zeros_like(self.amplitudes)
-        for t in range(self.n_templates):
-            idxs = self.get_template_spikes(t)
-            amplitudes_v[idxs] = self.amplitudes[idxs] * templates_amps[t]
-
-        return amplitudes_v
+        # unwhiten template waveforms on their channels of max amplitude
+        if self.sparse_templates.cols:
+            raise NotImplementedError
+        # apply the inverse whitening matrix to the template
+        templates_wfs = np.zeros_like(self.sparse_templates.data)  # nt, ns, nc
+        for n in np.arange(self.n_templates):
+            templates_wfs[n, :, :] = np.matmul(self.sparse_templates.data[n, :, :], self.wmi)
+
+        # The amplitude on each channel is the positive peak minus the negative
+        templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)
+
+        # The template arbitrary unit amplitude is the amplitude of its largest channel
+        # (but see below for true tempAmps)
+        templates_amps_au = np.max(templates_ch_amps, axis=1)
+        spike_amps = templates_amps_au[self.spike_templates] * self.amplitudes
+
+        with np.errstate(divide='ignore'):
+            # take the average spike amplitude per template
+            templates_amps_v = (np.bincount(self.spike_templates, weights=spike_amps) /
+                                np.bincount(self.spike_templates))
+            # scale back the template according to the spikes units
+            templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
+                                                       )[:, np.newaxis, np.newaxis]
+
+        return (spike_amps * sample2unit,
+                templates_physical_unit * sample2unit,
+                templates_amps_v * sample2unit)
 
     #--------------------------------------------------------------------------
     # Internal helper methods for public high-level methods
@@ -1232,7 +1258,8 @@ def save_spike_clusters(self, spike_clusters):
         logger.debug("Save spike clusters to `%s`.", path)
         np.save(path, spike_clusters)
 
-    def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_channels=None):
+    def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_channels=None,
+                                     sample2unit=1.):
         if self.traces is None:
             logger.warning(
                 "Spike waveforms could not be extracted as the raw data file is not available.")
@@ -1240,7 +1267,9 @@ def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_cha
         n_chunks_kept = 20  # TODO: better choice
 
         nst = max_n_spikes_per_template
-        nc = max_n_channels
+        nc = max_n_channels or self.n_closest_channels
+        nc = max(nc, self.n_closest_channels)
+
         assert nst > 0
         assert nc > 0
@@ -1275,7 +1304,7 @@ def save_spikes_subset_waveforms(self, max_n_spikes_per_template=None, max_n_cha
         # Extract waveforms from the raw data on a chunk by chunk basis.
         export_waveforms(
             path, self.traces, self.spike_samples[spike_ids], spike_channels,
-            n_samples_waveforms=self.n_samples_waveforms)
+            n_samples_waveforms=self.n_samples_waveforms, sample2unit=sample2unit)
 
         # Reload spike waveforms.
         self.spike_waveforms = self._load_spike_waveforms()
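
The per-template average in `get_amplitudes_true` relies on `np.bincount`: with `weights=` it returns the summed spike amplitude per template id, and dividing by the unweighted count turns that into a mean. A quick check with made-up numbers:

    import numpy as np

    spike_templates = np.array([0, 0, 1, 2, 2, 2])
    spike_amps = np.array([1., 3., 2., 4., 5., 6.])

    template_amps = (np.bincount(spike_templates, weights=spike_amps) /
                     np.bincount(spike_templates))
    print(template_amps)  # [2. 2. 5.]  (a template with no spikes would yield nan)
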
diff --git a/phylib/io/tests/test_alf.py b/phylib/io/tests/test_alf.py
index 0eefcd6..6384291 100644
--- a/phylib/io/tests/test_alf.py
+++ b/phylib/io/tests/test_alf.py
@@ -36,7 +36,7 @@ def __init__(self, tempdir):
         self.nt = 5
         self.ncd = 1000
         np.save(p / 'spike_times.npy', .01 * np.cumsum(nr.exponential(size=self.ns)))
-        np.save(p / 'spike_clusters.npy', nr.randint(low=0, high=self.nt, size=self.ns))
+        np.save(p / 'spike_clusters.npy', nr.randint(low=1, high=self.nt, size=self.ns))
         shutil.copy(p / 'spike_clusters.npy', p / 'spike_templates.npy')
         np.save(p / 'amplitudes.npy', nr.uniform(low=0.5, high=1.5, size=self.ns))
         np.save(p / 'channel_positions.npy', np.c_[np.arange(self.nc), np.zeros(self.nc)])
@@ -174,16 +174,22 @@ def check_conversion_output():
             assert f.exists()
 
         # makes sure the output dimensions match (especially clusters which should be 4)
-        cl_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('clusters.') and
-                    f.name.endswith('.npy')]
+        cl_shape = []
+        for f in new_files:
+            if f.name.startswith('clusters.') and f.name.endswith('.npy'):
+                cl_shape.append(np.load(f).shape[0])
+            elif f.name.startswith('clusters.') and f.name.endswith('.csv'):
+                with open(f) as fid:
+                    cl_shape.append(len(fid.readlines()) - 1)
         sp_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('spikes.')]
         ch_shape = [np.load(f).shape[0] for f in new_files if f.name.startswith('channels.')]
+
         assert len(set(cl_shape)) == 1
         assert len(set(sp_shape)) == 1
         assert len(set(ch_shape)) == 1
 
         dur = np.load(next(out_path.glob('clusters.peakToTrough*.npy')))
-        assert np.all(dur == np.array([18., -1., 9.5, 2.5, -2.]))
+        assert np.all(dur == np.array([-9.5, 3., 13., -4.5, -2.5]))
 
     def read_after_write():
         model = TemplateModel(dir_path=out_path, dat_path=dataset.dat_path,
diff --git a/phylib/io/traces.py b/phylib/io/traces.py
index dd11267..f337b1c 100644
--- a/phylib/io/traces.py
+++ b/phylib/io/traces.py
@@ -641,20 +641,21 @@ def iter_waveforms(traces, spike_samples, spike_channels, n_samples_waveforms=No
 
 
 def export_waveforms(
-        path, traces, spike_samples, spike_channels, n_samples_waveforms=None, cache=False):
+        path, traces, spike_samples, spike_channels, n_samples_waveforms=None, cache=False,
+        sample2unit=1):
     """Export a selection of spike waveforms to a npy file by iterating over the data on a chunk
     by chunk basis."""
     n_spikes = len(spike_samples)
     spike_channels = np.asarray(spike_channels, dtype=np.int32)
     n_channels_loc = spike_channels.shape[1]
     shape = (n_spikes, n_samples_waveforms, n_channels_loc)
-
-    writer = NpyWriter(path, shape, traces.dtype)
+    dtype = traces.dtype if sample2unit is None else float
+    writer = NpyWriter(path, shape, dtype)
     size_written = 0
     for waveforms in iter_waveforms(
             traces, spike_samples, spike_channels, n_samples_waveforms=n_samples_waveforms,
             cache=cache):
-        writer.append(waveforms)
+        writer.append(waveforms * sample2unit)
         size_written += waveforms.size
     writer.close()
     assert prod(shape) == size_written
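
A small illustration of the conversion that `export_waveforms` now applies: when `sample2unit` is given, int16 chunks are promoted to float and scaled as they are appended, so the .npy file on disk is already in physical units (toy array, hypothetical scale factor):

    import numpy as np

    raw_chunk = np.array([[120, -80], [64, 10]], dtype=np.int16)
    sample2unit = 2.34e-6  # hypothetical volts-per-sample factor
    scaled = raw_chunk * sample2unit  # int16 * Python float -> float64
    assert scaled.dtype == np.float64
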
diff --git a/phylib/stats/ccg.py b/phylib/stats/ccg.py
index 2c4a912..08febc7 100644
--- a/phylib/stats/ccg.py
+++ b/phylib/stats/ccg.py
@@ -148,7 +148,7 @@ def correlograms(
 
     # At a given shift, the mask precises which spikes have matching spikes
     # within the correlogram time window.
-    mask = np.ones_like(spike_samples, dtype=np.bool)
+    mask = np.ones_like(spike_samples, dtype=bool)
 
     correlograms = _create_correlograms_array(n_clusters, winsize_bins)
 
diff --git a/phylib/utils/_types.py b/phylib/utils/_types.py
index ee28962..3881d6b 100644
--- a/phylib/utils/_types.py
+++ b/phylib/utils/_types.py
@@ -15,8 +15,8 @@
 #------------------------------------------------------------------------------
 
 _ACCEPTED_ARRAY_DTYPES = (
-    np.float, np.float32, np.float64, np.int, np.int8, np.int16, np.uint8, np.uint16,
-    np.int32, np.int64, np.uint32, np.uint64, np.bool)
+    float, np.float32, np.float64, int, np.int8, np.int16, np.uint8, np.uint16,
+    np.int32, np.int64, np.uint32, np.uint64, bool)
 
 
 class Bunch(dict):
diff --git a/phylib/utils/tests/test_types.py b/phylib/utils/tests/test_types.py
index 35c0273..4bf9a17 100644
--- a/phylib/utils/tests/test_types.py
+++ b/phylib/utils/tests/test_types.py
@@ -88,13 +88,13 @@ def _check(arr):
     _check(_as_array(3.))
     _check(_as_array([3]))
 
-    _check(_as_array(3, np.float))
-    _check(_as_array(3., np.float))
-    _check(_as_array([3], np.float))
+    _check(_as_array(3, float))
+    _check(_as_array(3., float))
+    _check(_as_array([3], float))
     _check(_as_array(np.array([3])))
     with raises(ValueError):
-        _check(_as_array(np.array([3]), dtype=np.object))
-    _check(_as_array(np.array([3]), np.float))
+        _check(_as_array(np.array([3]), dtype=object))
+    _check(_as_array(np.array([3]), float))
 
     assert _as_array(None) is None
     assert not _is_array_like(None)
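
The dtype changes in these last three files track NumPy's removal of the scalar aliases `np.float`, `np.int`, `np.bool` and `np.object` (deprecated in NumPy 1.20, removed in 1.24). The aliases were just the Python builtins, so the builtins are drop-in replacements as dtype arguments:

    import numpy as np

    assert np.zeros(3, dtype=float).dtype == np.float64
    assert np.zeros(3, dtype=bool).dtype == np.bool_
    assert np.zeros(3, dtype=int).dtype in (np.int32, np.int64)  # platform default
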