From 022436fd63728f2b4536c48dde5c691f0b1f543d Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Thu, 21 Jul 2022 12:20:56 +0100 Subject: [PATCH 1/4] alf conversion after merge --- phylib/io/alf.py | 41 +++++++++++--- phylib/io/model.py | 131 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 140 insertions(+), 32 deletions(-) diff --git a/phylib/io/alf.py b/phylib/io/alf.py index 42c9c7f..dfa6f2c 100644 --- a/phylib/io/alf.py +++ b/phylib/io/alf.py @@ -174,16 +174,19 @@ def make_cluster_objects(self): """Create clusters.channels, clusters.waveformsDuration and clusters.amps""" peak_channel_path = self.dir_path / 'clusters.channels.npy' if not peak_channel_path.exists(): - self._save_npy(peak_channel_path.name, self.model.templates_channels) + # self._save_npy(peak_channel_path.name, self.model.templates_channels) + self._save_npy(peak_channel_path.name, self.model.clusters_channels) waveform_duration_path = self.dir_path / 'clusters.peakToTrough.npy' if not waveform_duration_path.exists(): - self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations) + # self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations) + self._save_npy(waveform_duration_path.name, self.model.clusters_waveforms_durations) # group by average over cluster number - camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan - camps[self.cluster_ids] = self.model.templates_amplitudes - amps_path = self.dir_path / 'clusters.amps.npy' + # camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan + camps = np.zeros(self.model.clusters_channels.shape[0], ) * np.nan + camps[self.cluster_ids] = self.model.clusters_amplitudes + amps_path = self.dir_path / 'clusters.amps.npy' # TODO these amplitudes are not on the same scale as the spike amps problem? self._save_npy(amps_path.name, camps * self.ampfactor) # clusters uuids @@ -233,7 +236,7 @@ def make_template_and_spikes_objects(self): # and not seconds self._save_npy('spikes.times.npy', self.model.spike_times) self._save_npy('spikes.samples.npy', self.model.spike_samples) - spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor) + spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor, use='templates') self._save_npy('spikes.amps.npy', spike_amps) self._save_npy('templates.amps.npy', template_amps) @@ -257,9 +260,35 @@ def make_template_and_spikes_objects(self): templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]] np.save(self.out_path.joinpath('templates.waveforms'), templates) np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds) + + _, clusters_v, cluster_amps = self.model.get_amplitudes_true(self.ampfactor, use='clusters') + n_clusters, n_wavsamps, nchall = clusters_v.shape + # for some datasets, 32 may be too much + ncw = min(self.model.n_closest_channels, nchall) + assert(n_clusters == self.model.n_clusters) + templates = np.zeros((n_clusters, n_wavsamps, ncw), dtype=np.float32) + templates_inds = np.zeros((n_clusters, ncw), dtype=np.int32) + # for each template, find the nearest channels to keep (one the same probe...) + for t in np.arange(n_clusters): + # here we need to fill with nans if it doesn't exists, but then can no longet be int (sorry) # or have it all 0 + channels = self.model.clusters_channels + + current_probe = self.model.channel_probes[channels[t]] + channel_distance = np.sum(np.abs( + self.model.channel_positions - + self.model.channel_positions[channels[t]]), axis=1) + channel_distance[self.model.channel_probes != current_probe] += np.inf + templates_inds[t, :] = np.argsort(channel_distance)[:ncw] + templates[t, ...] = clusters_v[t, :][:, templates_inds[t, :]] np.save(self.out_path.joinpath('clusters.waveforms'), templates) np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds) + # This should really be here + np.save(self.out_path.joinpath('clusters.amps'), cluster_amps) + + # np.save(self.out_path.joinpath('clusters.waveforms'), templates) + # np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds) + def rename_with_label(self): """add the label as an ALF part name before the extension if any label provided""" if not self.label: diff --git a/phylib/io/model.py b/phylib/io/model.py index ad008f5..e671e23 100644 --- a/phylib/io/model.py +++ b/phylib/io/model.py @@ -412,6 +412,18 @@ def _load_data(self): self.n_samples_waveforms = 0 self.n_channels_loc = 0 + # Clusters waveforms + if np.all(self.spike_clusters == self.spike_templates): + self.merge_map = {} + self.nan_clusters = [] + self.sparse_clusters = self.sparse_templates + self.n_clusters = self.spike_templates.max() + 1 + else: + if self.sparse_templates.cols is None: + self.merge_map, self.nan_clusters = self.get_merge_map() + self.sparse_clusters = self.cluster_waveforms() + self.n_clusters = self.spike_clusters.max() + 1 + # Spike waveforms (optional, otherwise fetched from raw data as needed). self.spike_waveforms = self._load_spike_waveforms() @@ -861,12 +873,12 @@ def _template_n_channels(self, template_id, n_channels): channel_ids += [-1] * (n_channels - len(channel_ids)) return channel_ids - def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None): + def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True): """Return data for one template.""" if not self.sparse_templates: return template_w = self.sparse_templates.data[template_id, ...] - template = self._unwhiten(template_w).astype(np.float32) + template = self._unwhiten(template_w).astype(np.float32) if unwhiten else template_w assert template.ndim == 2 channel_ids_, amplitude, best_channel = self._find_best_channels( template, amplitude_threshold=amplitude_threshold) @@ -881,7 +893,7 @@ def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold channel_ids=channel_ids, ) - def _get_template_sparse(self, template_id): + def _get_template_sparse(self, template_id, unwhiten=True): data, cols = self.sparse_templates.data, self.sparse_templates.cols assert cols is not None template_w, channel_ids = data[template_id], cols[template_id] @@ -902,7 +914,7 @@ def _get_template_sparse(self, template_id): channel_ids = channel_ids.astype(np.uint32) # Unwhiten. - template = self._unwhiten(template_w, channel_ids=channel_ids) + template = self._unwhiten(template_w, channel_ids=channel_ids) if unwhiten else template_w template = template.astype(np.float32) assert template.ndim == 2 assert template.shape[1] == len(channel_ids) @@ -920,17 +932,31 @@ def _get_template_sparse(self, template_id): ) return out + def get_merge_map(self): + """"Gets the merge mapping for between spikes.clusters and spikes.templates""" + inverse_mapping_dict = {key: [] for key in range(np.max(self.spike_clusters) + 1)} + for temp in np.unique(self.spike_templates): + idx = np.where(self.spike_templates == temp)[0] + new_idx = self.spike_clusters[idx] + mapping = np.unique(new_idx) + for n in mapping: + inverse_mapping_dict[n].append(temp) + + nan_idx = np.array([idx for idx, val in inverse_mapping_dict.items() if len(val) == 0]) + + return inverse_mapping_dict, nan_idx + #-------------------------------------------------------------------------- # Data access methods #-------------------------------------------------------------------------- - def get_template(self, template_id, channel_ids=None, amplitude_threshold=None): + def get_template(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True): """Get data about a template.""" if self.sparse_templates and self.sparse_templates.cols is not None: - return self._get_template_sparse(template_id) + return self._get_template_sparse(template_id, unwhiten=unwhiten) else: return self._get_template_dense( - template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold) + template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold, unwhiten=unwhiten) def get_waveforms(self, spike_ids, channel_ids=None): """Return spike waveforms on specified channels.""" @@ -1047,7 +1073,7 @@ def get_depths(self): # take only first component features = self.sparse_features.data[ispi, :, 0] features = np.maximum(features, 0) ** 2 # takes only positive values into account - ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.uint32) + ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.uint32) ## TODO this should be templates, otherwise won't work # features = np.square(self.sparse_features.data[ispi, :, 0]) # ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64) ypos = self.channel_positions[ichannels, 1] @@ -1059,7 +1085,7 @@ def get_depths(self): break return spikes_depths - def get_amplitudes_true(self, sample2unit=1.): + def get_amplitudes_true(self, sample2unit=1., use='templates'): """Convert spike amplitude values to input amplitudes units via scaling by unwhitened template waveform. :param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.) @@ -1074,13 +1100,22 @@ def get_amplitudes_true(self, sample2unit=1.): # spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps)) # to rescale the template, + if use == 'clusters': + sparse = self.sparse_clusters + spikes = self.spike_clusters + n_wav = self.n_clusters + else: + sparse = self.sparse_templates + spikes = self.spike_templates + n_wav = self.n_templates + # unwhiten template waveforms on their channels of max amplitude - if self.sparse_templates.cols: + if sparse.cols: raise NotImplementedError # apply the inverse whitening matrix to the template - templates_wfs = np.zeros_like(self.sparse_templates.data) # nt, ns, nc - for n in np.arange(self.n_templates): - templates_wfs[n, :, :] = np.matmul(self.sparse_templates.data[n, :, :], self.wmi) + templates_wfs = np.zeros_like(sparse.data) # nt, ns, nc + for n in np.arange(n_wav): + templates_wfs[n, :, :] = np.matmul(sparse.data[n, :, :], self.wmi) # The amplitude on each channel is the positive peak minus the negative templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1) @@ -1088,12 +1123,12 @@ def get_amplitudes_true(self, sample2unit=1.): # The template arbitrary unit amplitude is the amplitude of its largest channel # (but see below for true tempAmps) templates_amps_au = np.max(templates_ch_amps, axis=1) - spike_amps = templates_amps_au[self.spike_templates] * self.amplitudes + spike_amps = templates_amps_au[spikes] * self.amplitudes with np.errstate(divide='ignore', invalid='ignore'): # take the average spike amplitude per template - templates_amps_v = (np.bincount(self.spike_templates, weights=spike_amps) / - np.bincount(self.spike_templates)) + templates_amps_v = (np.bincount(spikes, weights=spike_amps) / + np.bincount(spikes)) # scale back the template according to the spikes units templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au )[:, np.newaxis, np.newaxis] @@ -1167,7 +1202,7 @@ def get_template_waveforms(self, template_id): template = self.get_template(template_id) return template.template if template else None - def get_cluster_mean_waveforms(self, cluster_id): + def get_cluster_mean_waveforms(self, cluster_id, unwhiten=True): """Return the mean template waveforms of a cluster, as a weighted average of the template waveforms from which the cluster originates from.""" count = self.get_template_counts(cluster_id) @@ -1175,10 +1210,10 @@ def get_cluster_mean_waveforms(self, cluster_id): template_ids = np.nonzero(count)[0] count = count[template_ids] # Get local channels of the best template for the given cluster. - template = self.get_template(best_template) + template = self.get_template(best_template, unwhiten=unwhiten) channel_ids = template.channel_ids # Get all templates from which this cluster stems from. - templates = [self.get_template(template_id) for template_id in template_ids] + templates = [self.get_template(template_id, unwhiten=unwhiten) for template_id in template_ids] # Construct the waveforms array. ns = self.n_samples_waveforms data = np.zeros((len(template_ids), ns, self.n_channels)) @@ -1205,16 +1240,27 @@ def get_cluster_spike_waveforms(self, cluster_id): @property def templates_channels(self): """Returns a vector of peak channels for all templates""" - tmp = self.sparse_templates.data + return self._channels(self.sparse_templates) + + @property + def clusters_channels(self): + """Returns a vector of peak channels for all templates""" + channels = self._channels(self.sparse_clusters) + return channels + + def _channels(self, sparse): + # TODO document and better name + tmp = sparse.data n_templates, n_samples, n_channels = tmp.shape - if self.sparse_templates.cols is None: + if sparse.cols is None: template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1) else: # when the templates are sparse, the first channel is the highest amplitude channel - template_peak_channels = self.sparse_templates.cols[:, 0] + template_peak_channels = sparse.cols[:, 0] assert template_peak_channels.shape == (n_templates,) return template_peak_channels + @property def templates_probes(self): """Returns a vector of probe index for all templates""" @@ -1223,16 +1269,32 @@ def templates_probes(self): @property def templates_amplitudes(self): """Returns the average amplitude per cluster""" - tid = np.unique(self.spike_templates) - n = np.bincount(self.spike_templates)[tid] - a = np.bincount(self.spike_templates, weights=self.amplitudes)[tid] + return self._amplitudes(self.spike_templates) + + @property + def clusters_amplitudes(self): + """Returns the average amplitude per cluster""" + return self._amplitudes(self.spike_clusters) + + def _amplitudes(self, tmp): + tid = np.unique(tmp) + n = np.bincount(tmp)[tid] + a = np.bincount(tmp, weights=self.amplitudes)[tid] n[np.isnan(n)] = 1 return a / n @property def templates_waveforms_durations(self): """Returns a vector of waveform durations (ms) for all templates""" - tmp = self.sparse_templates.data + return self._waveform_durations(self.sparse_templates.data) + + @property + def clusters_waveforms_durations(self): + """Returns a vector of waveform durations (ms) for all templates""" + waveform_duration = self._waveform_durations(self.sparse_clusters.data) + return waveform_duration + + def _waveform_durations(self, tmp): n_templates, n_samples, n_channels = tmp.shape # Compute the peak channels for each template. template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1) @@ -1241,6 +1303,23 @@ def templates_waveforms_durations(self): (n_templates, n_channels), mode='raise', order='C') return durations.flatten()[ind].astype(np.float64) / self.sample_rate * 1e3 + def cluster_waveforms(self): + """ + Computes the cluster waveforms for split and merged clusters + :return: + """ + # Only non sparse implementation + ns = self.n_samples_waveforms # TODO put not implemented warning + data = np.zeros((np.max(self.cluster_ids) + 1, ns, self.n_channels)) # TODO can be self.n_clusters + for clust, val in self.merge_map.items(): + if len(val) > 1: + mean_waveform = self.get_cluster_mean_waveforms(clust, unwhiten=False) + data[clust, :, mean_waveform.channel_ids] = np.swapaxes(mean_waveform.mean_waveforms, 0, 1) + elif len(val) == 1: + data[clust, :, :] = self.sparse_templates.data[val[0], :, :] + + return Bunch(data=data, cols=None) + #-------------------------------------------------------------------------- # Saving methods #-------------------------------------------------------------------------- From 1d09d3b9ce151a8b4750aecd266591920e04756f Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 22 Jul 2022 13:27:32 +0100 Subject: [PATCH 2/4] tests for merged output --- phylib/io/alf.py | 1 - phylib/io/model.py | 12 +++--- phylib/io/tests/conftest.py | 15 +++++++- phylib/io/tests/test_alf.py | 69 +++++++++++++++++++++++++++++++++++ phylib/io/tests/test_model.py | 19 +++++++++- 5 files changed, 106 insertions(+), 10 deletions(-) diff --git a/phylib/io/alf.py b/phylib/io/alf.py index dfa6f2c..e218249 100644 --- a/phylib/io/alf.py +++ b/phylib/io/alf.py @@ -270,7 +270,6 @@ def make_template_and_spikes_objects(self): templates_inds = np.zeros((n_clusters, ncw), dtype=np.int32) # for each template, find the nearest channels to keep (one the same probe...) for t in np.arange(n_clusters): - # here we need to fill with nans if it doesn't exists, but then can no longet be int (sorry) # or have it all 0 channels = self.model.clusters_channels current_probe = self.model.channel_probes[channels[t]] diff --git a/phylib/io/model.py b/phylib/io/model.py index e671e23..4e80f63 100644 --- a/phylib/io/model.py +++ b/phylib/io/model.py @@ -413,16 +413,14 @@ def _load_data(self): self.n_channels_loc = 0 # Clusters waveforms - if np.all(self.spike_clusters == self.spike_templates): + if not np.all(self.spike_clusters == self.spike_templates) and self.sparse_templates.cols is None: + self.merge_map, _ = self.get_merge_map() + self.sparse_clusters = self.cluster_waveforms() + self.n_clusters = self.spike_clusters.max() + 1 + else: self.merge_map = {} - self.nan_clusters = [] self.sparse_clusters = self.sparse_templates self.n_clusters = self.spike_templates.max() + 1 - else: - if self.sparse_templates.cols is None: - self.merge_map, self.nan_clusters = self.get_merge_map() - self.sparse_clusters = self.cluster_waveforms() - self.n_clusters = self.spike_clusters.max() + 1 # Spike waveforms (optional, otherwise fetched from raw data as needed). self.spike_waveforms = self._load_spike_waveforms() diff --git a/phylib/io/tests/conftest.py b/phylib/io/tests/conftest.py index d29b459..7badbf1 100644 --- a/phylib/io/tests/conftest.py +++ b/phylib/io/tests/conftest.py @@ -98,6 +98,19 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True): _remove(tempdir / 'whitening_mat_inv.npy') _remove(tempdir / 'sim_binary.dat') + if param == 'merged': + # remove this file to make templates dense + _remove(tempdir / 'template_ind.npy') + clus = np.load(tempdir / 'spike_clusters.npy') + max_clus = np.max(clus) + # merge cluster 0 and 1 + clus[np.bitwise_or(clus == 0, clus == 1)] = max_clus + 1 + # split cluster 9 into two clusters + idx = np.where(clus == 9)[0] + clus[idx[0:3]] = max_clus + 2 + clus[idx[3:]] = max_clus + 3 + np.save(tempdir / 'spike_clusters.npy', clus) + # Spike attributes. if has_spike_attributes: write_array(tempdir / 'spike_fail.npy', np.full(10, np.nan)) # wrong number of spikes @@ -120,7 +133,7 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True): return template_path -@fixture(scope='function', params=('dense', 'sparse', 'misc')) +@fixture(scope='function', params=('dense', 'sparse', 'misc', 'merged')) def template_path_full(tempdir, request): return _make_dataset(tempdir, request.param) diff --git a/phylib/io/tests/test_alf.py b/phylib/io/tests/test_alf.py index 6384291..5d9d0b8 100644 --- a/phylib/io/tests/test_alf.py +++ b/phylib/io/tests/test_alf.py @@ -210,3 +210,72 @@ def read_after_write(): c.convert(out_path, label='probe00') check_conversion_output() read_after_write() + + +def test_merger(dataset): + + path = Path(dataset.tmp_dir) + out_path = path / 'alf' + + model = TemplateModel( + dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) + + c = EphysAlfCreator(model) + c.convert(out_path) + + model.close() + + # path.joinpath('_phy_spikes_subset.channels.npy').unlink() + # path.joinpath('_phy_spikes_subset.waveforms.npy').unlink() + # path.joinpath('_phy_spikes_subset.spikes.npy').unlink() + + out_path_merge = path / 'alf_merge' + spike_clusters = dataset._load('spike_clusters.npy') + clu, n_clu = np.unique(spike_clusters, return_counts=True) + + # merge the first two clusters + merge_clu = clu[0:2] + spike_clusters[np.bitwise_or(spike_clusters == clu[0], spike_clusters == clu[1])] = np.max(clu) + 1 + # split the cluster with the most spikes + split_clu = clu[-1] + idx = np.where(spike_clusters == split_clu)[0] + spike_clusters[idx[0:int(n_clu[-1] / 2)]] = np.max(clu) + 2 + spike_clusters[idx[int(n_clu[-1] / 2):]] = np.max(clu) + 3 + + np.save(path / 'spike_clusters.npy', spike_clusters) + + model = TemplateModel( + dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) + print(model.merge_map) + c = EphysAlfCreator(model) + c.convert(out_path_merge) + + # Test that the split are the same for the expected datasets + clu_old = np.load(next(out_path.glob('clusters.peakToTrough.npy'))) + clu_new = np.load(next(out_path_merge.glob('clusters.peakToTrough.npy'))) + assert clu_old[split_clu] == clu_new[np.max(clu) + 2] + assert clu_old[split_clu] == clu_new[np.max(clu) + 3] + assert clu_new[split_clu] == 0 + assert clu_new[merge_clu[0]] == 0 + assert clu_new[merge_clu[1]] == 0 + + clu_old = np.load(next(out_path.glob('clusters.channels.npy'))) + clu_new = np.load(next(out_path_merge.glob('clusters.channels.npy'))) + assert clu_old[split_clu] == clu_new[np.max(clu) + 2] + assert clu_old[split_clu] == clu_new[np.max(clu) + 3] + assert clu_new[split_clu] == 0 + assert clu_new[merge_clu[0]] == 0 + assert clu_new[merge_clu[1]] == 0 + + clu_old = np.load(next(out_path.glob('clusters.depths.npy'))) + clu_new = np.load(next(out_path_merge.glob('clusters.depths.npy'))) + assert clu_old[split_clu] == clu_new[np.max(clu) + 2] + assert clu_old[split_clu] == clu_new[np.max(clu) + 3] + assert clu_new[split_clu] == 0 + assert clu_new[merge_clu[0]] == 0 + assert clu_new[merge_clu[1]] == 0 + + clu_old = np.load(next(out_path.glob('clusters.waveformsChannels.npy'))) + clu_new = np.load(next(out_path_merge.glob('clusters.waveformsChannels.npy'))) + assert np.array_equal(clu_old[split_clu], clu_new[np.max(clu) + 2]) + assert np.array_equal(clu_old[split_clu], clu_new[np.max(clu) + 3]) diff --git a/phylib/io/tests/test_model.py b/phylib/io/tests/test_model.py index 278efca..e58c100 100644 --- a/phylib/io/tests/test_model.py +++ b/phylib/io/tests/test_model.py @@ -14,7 +14,8 @@ # from phylib.utils import Bunch from phylib.utils.testing import captured_output -from ..model import from_sparse, load_model +# from ..model import from_sparse, load_model +from phylib.io.model import from_sparse, load_model logger = logging.getLogger(__name__) @@ -113,6 +114,22 @@ def test_model_depth(template_model): assert depths.shape == (template_model.n_spikes,) +def test_model_merge(template_model_full): + m = template_model_full + + # This is the case where we can do the merging + if not np.all(m.spike_templates == m.spike_clusters) and m.sparse_clusters.cols is None: + assert len(m.merge_map) > 0 + assert not np.array_equal(m.sparse_clusters.data, m.sparse_templates.data) + assert m.sparse_clusters.data.shape[0] == m.n_clusters + assert m.sparse_templates.data.shape[0] == m.n_templates + + else: + assert len(m.merge_map) == 0 + assert np.array_equal(m.sparse_clusters.data, m.sparse_templates.data) + assert np.array_equal(m.n_templates, m.n_clusters) + + def test_model_save(template_model_full): m = template_model_full m.save_metadata('test', {1: 1}) From fb1860bc848f57783e032e0b03391b0b93f154f3 Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Fri, 22 Jul 2022 16:00:24 +0100 Subject: [PATCH 3/4] set values to nans --- phylib/io/alf.py | 11 ++++++----- phylib/io/model.py | 21 +++++++++++---------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/phylib/io/alf.py b/phylib/io/alf.py index e218249..169e3bc 100644 --- a/phylib/io/alf.py +++ b/phylib/io/alf.py @@ -180,13 +180,15 @@ def make_cluster_objects(self): waveform_duration_path = self.dir_path / 'clusters.peakToTrough.npy' if not waveform_duration_path.exists(): # self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations) - self._save_npy(waveform_duration_path.name, self.model.clusters_waveforms_durations) + waveform_duration = self.model.clusters_waveforms_durations + waveform_duration[self.model.nan_idx] = np.nan + self._save_npy(waveform_duration_path.name, waveform_duration) # group by average over cluster number # camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan camps = np.zeros(self.model.clusters_channels.shape[0], ) * np.nan camps[self.cluster_ids] = self.model.clusters_amplitudes - amps_path = self.dir_path / 'clusters.amps.npy' # TODO these amplitudes are not on the same scale as the spike amps problem? + amps_path = self.dir_path / 'clusters.amps.npy' self._save_npy(amps_path.name, camps * self.ampfactor) # clusters uuids @@ -219,6 +221,7 @@ def make_depths(self): n_clusters = cluster_channels.shape[0] clusters_depths = channel_positions[cluster_channels, 1] + clusters_depths[self.model.nan_idx] = np.nan assert clusters_depths.shape == (n_clusters,) if self.model.sparse_features is None: @@ -282,11 +285,9 @@ def make_template_and_spikes_objects(self): np.save(self.out_path.joinpath('clusters.waveforms'), templates) np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds) - # This should really be here + # TODO check if we should save this here, will be inconsistent with what we have at the moment np.save(self.out_path.joinpath('clusters.amps'), cluster_amps) - # np.save(self.out_path.joinpath('clusters.waveforms'), templates) - # np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds) def rename_with_label(self): """add the label as an ALF part name before the extension if any label provided""" diff --git a/phylib/io/model.py b/phylib/io/model.py index 4e80f63..9c198f3 100644 --- a/phylib/io/model.py +++ b/phylib/io/model.py @@ -414,11 +414,12 @@ def _load_data(self): # Clusters waveforms if not np.all(self.spike_clusters == self.spike_templates) and self.sparse_templates.cols is None: - self.merge_map, _ = self.get_merge_map() + self.merge_map, self.nan_idx = self.get_merge_map() self.sparse_clusters = self.cluster_waveforms() self.n_clusters = self.spike_clusters.max() + 1 else: self.merge_map = {} + self.nan_idx = [] self.sparse_clusters = self.sparse_templates self.n_clusters = self.spike_templates.max() + 1 @@ -931,7 +932,7 @@ def _get_template_sparse(self, template_id, unwhiten=True): return out def get_merge_map(self): - """"Gets the merge mapping for between spikes.clusters and spikes.templates""" + """"Gets the maps of merges and splits between spikes.clusters and spikes.templates""" inverse_mapping_dict = {key: [] for key in range(np.max(self.spike_clusters) + 1)} for temp in np.unique(self.spike_templates): idx = np.where(self.spike_templates == temp)[0] @@ -1071,7 +1072,7 @@ def get_depths(self): # take only first component features = self.sparse_features.data[ispi, :, 0] features = np.maximum(features, 0) ** 2 # takes only positive values into account - ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.uint32) ## TODO this should be templates, otherwise won't work + ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.uint32) # features = np.square(self.sparse_features.data[ispi, :, 0]) # ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64) ypos = self.channel_positions[ichannels, 1] @@ -1237,17 +1238,17 @@ def get_cluster_spike_waveforms(self, cluster_id): @property def templates_channels(self): - """Returns a vector of peak channels for all templates""" + """Returns a vector of peak channels for all templates waveforms""" return self._channels(self.sparse_templates) @property def clusters_channels(self): - """Returns a vector of peak channels for all templates""" + """Returns a vector of peak channels for all clusters waveforms""" channels = self._channels(self.sparse_clusters) return channels def _channels(self, sparse): - # TODO document and better name + """ Gets peak channels for each waveform""" tmp = sparse.data n_templates, n_samples, n_channels = tmp.shape if sparse.cols is None: @@ -1258,7 +1259,6 @@ def _channels(self, sparse): assert template_peak_channels.shape == (n_templates,) return template_peak_channels - @property def templates_probes(self): """Returns a vector of probe index for all templates""" @@ -1275,6 +1275,7 @@ def clusters_amplitudes(self): return self._amplitudes(self.spike_clusters) def _amplitudes(self, tmp): + """ Compute average amplitude for spikes""" tid = np.unique(tmp) n = np.bincount(tmp)[tid] a = np.bincount(tmp, weights=self.amplitudes)[tid] @@ -1288,7 +1289,7 @@ def templates_waveforms_durations(self): @property def clusters_waveforms_durations(self): - """Returns a vector of waveform durations (ms) for all templates""" + """Returns a vector of waveform durations (ms) for all clusters""" waveform_duration = self._waveform_durations(self.sparse_clusters.data) return waveform_duration @@ -1307,8 +1308,8 @@ def cluster_waveforms(self): :return: """ # Only non sparse implementation - ns = self.n_samples_waveforms # TODO put not implemented warning - data = np.zeros((np.max(self.cluster_ids) + 1, ns, self.n_channels)) # TODO can be self.n_clusters + ns = self.n_samples_waveforms + data = np.zeros((np.max(self.cluster_ids) + 1, ns, self.n_channels)) for clust, val in self.merge_map.items(): if len(val) > 1: mean_waveform = self.get_cluster_mean_waveforms(clust, unwhiten=False) From 6db4487065c2cd42cecdb897c4ae5004e943b40b Mon Sep 17 00:00:00 2001 From: Mayo Faulkner Date: Tue, 18 Oct 2022 18:22:09 +0100 Subject: [PATCH 4/4] tests expect nans --- phylib/io/tests/test_alf.py | 16 +++++++++------- phylib/io/tests/test_model.py | 3 +-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/phylib/io/tests/test_alf.py b/phylib/io/tests/test_alf.py index 5d9d0b8..fc075d1 100644 --- a/phylib/io/tests/test_alf.py +++ b/phylib/io/tests/test_alf.py @@ -20,6 +20,7 @@ from ..model import TemplateModel + #------------------------------------------------------------------------------ # Fixture #------------------------------------------------------------------------------ @@ -246,7 +247,7 @@ def test_merger(dataset): model = TemplateModel( dir_path=path, dat_path=dataset.dat_path, sample_rate=2000, n_channels_dat=dataset.nc) - print(model.merge_map) + c = EphysAlfCreator(model) c.convert(out_path_merge) @@ -255,9 +256,10 @@ def test_merger(dataset): clu_new = np.load(next(out_path_merge.glob('clusters.peakToTrough.npy'))) assert clu_old[split_clu] == clu_new[np.max(clu) + 2] assert clu_old[split_clu] == clu_new[np.max(clu) + 3] - assert clu_new[split_clu] == 0 - assert clu_new[merge_clu[0]] == 0 - assert clu_new[merge_clu[1]] == 0 + + assert np.isnan([clu_new[split_clu]])[0] + assert np.isnan([clu_new[merge_clu[0]]])[0] + assert np.isnan([clu_new[merge_clu[1]]])[0] clu_old = np.load(next(out_path.glob('clusters.channels.npy'))) clu_new = np.load(next(out_path_merge.glob('clusters.channels.npy'))) @@ -271,9 +273,9 @@ def test_merger(dataset): clu_new = np.load(next(out_path_merge.glob('clusters.depths.npy'))) assert clu_old[split_clu] == clu_new[np.max(clu) + 2] assert clu_old[split_clu] == clu_new[np.max(clu) + 3] - assert clu_new[split_clu] == 0 - assert clu_new[merge_clu[0]] == 0 - assert clu_new[merge_clu[1]] == 0 + assert np.isnan([clu_new[split_clu]])[0] + assert np.isnan([clu_new[merge_clu[0]]])[0] + assert np.isnan([clu_new[merge_clu[1]]])[0] clu_old = np.load(next(out_path.glob('clusters.waveformsChannels.npy'))) clu_new = np.load(next(out_path_merge.glob('clusters.waveformsChannels.npy'))) diff --git a/phylib/io/tests/test_model.py b/phylib/io/tests/test_model.py index e58c100..f9b32be 100644 --- a/phylib/io/tests/test_model.py +++ b/phylib/io/tests/test_model.py @@ -14,8 +14,7 @@ # from phylib.utils import Bunch from phylib.utils.testing import captured_output -# from ..model import from_sparse, load_model -from phylib.io.model import from_sparse, load_model +from ..model import from_sparse, load_model logger = logging.getLogger(__name__)