39 changes: 34 additions & 5 deletions phylib/io/alf.py
@@ -174,15 +174,20 @@ def make_cluster_objects(self):
"""Create clusters.channels, clusters.waveformsDuration and clusters.amps"""
peak_channel_path = self.dir_path / 'clusters.channels.npy'
if not peak_channel_path.exists():
self._save_npy(peak_channel_path.name, self.model.templates_channels)
# self._save_npy(peak_channel_path.name, self.model.templates_channels)
self._save_npy(peak_channel_path.name, self.model.clusters_channels)

waveform_duration_path = self.dir_path / 'clusters.peakToTrough.npy'
if not waveform_duration_path.exists():
self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
# self._save_npy(waveform_duration_path.name, self.model.templates_waveforms_durations)
waveform_duration = self.model.clusters_waveforms_durations
waveform_duration[self.model.nan_idx] = np.nan
self._save_npy(waveform_duration_path.name, waveform_duration)

# group by average over cluster number
camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
camps[self.cluster_ids] = self.model.templates_amplitudes
# camps = np.zeros(self.model.templates_channels.shape[0],) * np.nan
camps = np.zeros(self.model.clusters_channels.shape[0], ) * np.nan
camps[self.cluster_ids] = self.model.clusters_amplitudes
amps_path = self.dir_path / 'clusters.amps.npy'
self._save_npy(amps_path.name, camps * self.ampfactor)
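To make the new amplitude vector concrete, here is a minimal standalone sketch of the NaN padding above, using made-up cluster ids and amplitudes rather than values from a real dataset:

import numpy as np

# Hypothetical values: 5 cluster slots, but only ids 0, 2 and 4 still hold
# spikes (1 and 3 were emptied by merges), so their slots stay NaN.
cluster_ids = np.array([0, 2, 4])
clusters_amplitudes = np.array([10., 12., 9.])

camps = np.zeros(5) * np.nan            # NaN-initialised, one slot per cluster id
camps[cluster_ids] = clusters_amplitudes
print(camps)                            # [10. nan 12. nan  9.]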

@@ -216,6 +221,7 @@ def make_depths(self):
n_clusters = cluster_channels.shape[0]

clusters_depths = channel_positions[cluster_channels, 1]
clusters_depths[self.model.nan_idx] = np.nan
assert clusters_depths.shape == (n_clusters,)

if self.model.sparse_features is None:
@@ -233,7 +239,7 @@ def make_template_and_spikes_objects(self):
# and not seconds
self._save_npy('spikes.times.npy', self.model.spike_times)
self._save_npy('spikes.samples.npy', self.model.spike_samples)
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor)
spike_amps, templates_v, template_amps = self.model.get_amplitudes_true(self.ampfactor, use='templates')
self._save_npy('spikes.amps.npy', spike_amps)
self._save_npy('templates.amps.npy', template_amps)

@@ -257,9 +263,32 @@ def make_template_and_spikes_objects(self):
templates[t, ...] = templates_v[t, :][:, templates_inds[t, :]]
np.save(self.out_path.joinpath('templates.waveforms'), templates)
np.save(self.out_path.joinpath('templates.waveformsChannels'), templates_inds)

_, clusters_v, cluster_amps = self.model.get_amplitudes_true(self.ampfactor, use='clusters')
n_clusters, n_wavsamps, nchall = clusters_v.shape
# for some datasets, 32 may be too much
ncw = min(self.model.n_closest_channels, nchall)
assert n_clusters == self.model.n_clusters
templates = np.zeros((n_clusters, n_wavsamps, ncw), dtype=np.float32)
templates_inds = np.zeros((n_clusters, ncw), dtype=np.int32)
# for each cluster, find the nearest channels to keep (on the same probe)
for t in np.arange(n_clusters):
channels = self.model.clusters_channels

current_probe = self.model.channel_probes[channels[t]]
channel_distance = np.sum(np.abs(
self.model.channel_positions -
self.model.channel_positions[channels[t]]), axis=1)
channel_distance[self.model.channel_probes != current_probe] += np.inf
templates_inds[t, :] = np.argsort(channel_distance)[:ncw]
templates[t, ...] = clusters_v[t, :][:, templates_inds[t, :]]
np.save(self.out_path.joinpath('clusters.waveforms'), templates)
np.save(self.out_path.joinpath('clusters.waveformsChannels'), templates_inds)

# TODO: check whether we should save this here; it will be inconsistent with what we have at the moment
np.save(self.out_path.joinpath('clusters.amps'), cluster_amps)
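The channel-selection loop above can be exercised in isolation. Here is a toy sketch with a hypothetical two-probe geometry; channel_positions, channel_probes and ncw are stand-ins for the corresponding model attributes:

import numpy as np

# Six channels on two probes; keep the ncw=3 channels closest (L1 distance)
# to the peak channel, never crossing to the other probe.
channel_positions = np.array([[0., 0.], [0., 20.], [0., 50.],
                              [100., 0.], [100., 20.], [100., 50.]])
channel_probes = np.array([0, 0, 0, 1, 1, 1])
peak, ncw = 1, 3

channel_distance = np.sum(np.abs(channel_positions - channel_positions[peak]), axis=1)
channel_distance[channel_probes != channel_probes[peak]] += np.inf
closest = np.argsort(channel_distance)[:ncw]
print(closest)                          # [1 0 2]: same-probe neighbours of the peak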


def rename_with_label(self):
"""add the label as an ALF part name before the extension if any label provided"""
if not self.label:
132 changes: 105 additions & 27 deletions phylib/io/model.py
@@ -412,6 +412,17 @@ def _load_data(self):
self.n_samples_waveforms = 0
self.n_channels_loc = 0

# Clusters waveforms
if not np.all(self.spike_clusters == self.spike_templates) and self.sparse_templates.cols is None:
self.merge_map, self.nan_idx = self.get_merge_map()
self.sparse_clusters = self.cluster_waveforms()
self.n_clusters = self.spike_clusters.max() + 1
else:
self.merge_map = {}
self.nan_idx = []
self.sparse_clusters = self.sparse_templates
self.n_clusters = self.spike_templates.max() + 1

# Spike waveforms (optional, otherwise fetched from raw data as needed).
self.spike_waveforms = self._load_spike_waveforms()

Expand Down Expand Up @@ -861,12 +872,12 @@ def _template_n_channels(self, template_id, n_channels):
channel_ids += [-1] * (n_channels - len(channel_ids))
return channel_ids

def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None):
def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
"""Return data for one template."""
if not self.sparse_templates:
return
template_w = self.sparse_templates.data[template_id, ...]
template = self._unwhiten(template_w).astype(np.float32)
template = self._unwhiten(template_w).astype(np.float32) if unwhiten else template_w
assert template.ndim == 2
channel_ids_, amplitude, best_channel = self._find_best_channels(
template, amplitude_threshold=amplitude_threshold)
@@ -881,7 +892,7 @@ def _get_template_dense(self, template_id, channel_ids=None, amplitude_threshold
channel_ids=channel_ids,
)

def _get_template_sparse(self, template_id):
def _get_template_sparse(self, template_id, unwhiten=True):
data, cols = self.sparse_templates.data, self.sparse_templates.cols
assert cols is not None
template_w, channel_ids = data[template_id], cols[template_id]
@@ -902,7 +913,7 @@ def _get_template_sparse(self, template_id):
channel_ids = channel_ids.astype(np.uint32)

# Unwhiten.
template = self._unwhiten(template_w, channel_ids=channel_ids)
template = self._unwhiten(template_w, channel_ids=channel_ids) if unwhiten else template_w
template = template.astype(np.float32)
assert template.ndim == 2
assert template.shape[1] == len(channel_ids)
@@ -920,17 +931,31 @@ def _get_template_sparse(self, template_id):
)
return out

def get_merge_map(self):
""""Gets the maps of merges and splits between spikes.clusters and spikes.templates"""
inverse_mapping_dict = {key: [] for key in range(np.max(self.spike_clusters) + 1)}
for temp in np.unique(self.spike_templates):
idx = np.where(self.spike_templates == temp)[0]
new_idx = self.spike_clusters[idx]
mapping = np.unique(new_idx)
for n in mapping:
inverse_mapping_dict[n].append(temp)

nan_idx = np.array([idx for idx, val in inverse_mapping_dict.items() if len(val) == 0])

return inverse_mapping_dict, nan_idx
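For reference, a toy run of the mapping built by get_merge_map, with hypothetical spike_templates / spike_clusters vectors encoding one merge (templates 0 and 1 into cluster 4) and one split (template 3 into clusters 5 and 6):

import numpy as np

spike_templates = np.array([0, 1, 2, 3, 3, 0])
spike_clusters = np.array([4, 4, 2, 5, 6, 4])

inverse = {key: [] for key in range(spike_clusters.max() + 1)}
for temp in np.unique(spike_templates):
    idx = np.where(spike_templates == temp)[0]
    for n in np.unique(spike_clusters[idx]):
        inverse[n].append(temp)
nan_idx = np.array([i for i, v in inverse.items() if len(v) == 0])
# inverse -> {0: [], 1: [], 2: [2], 3: [], 4: [0, 1], 5: [3], 6: [3]}
# nan_idx -> [0 1 3]: empty cluster slots, padded with NaN downstream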

#--------------------------------------------------------------------------
# Data access methods
#--------------------------------------------------------------------------

def get_template(self, template_id, channel_ids=None, amplitude_threshold=None):
def get_template(self, template_id, channel_ids=None, amplitude_threshold=None, unwhiten=True):
"""Get data about a template."""
if self.sparse_templates and self.sparse_templates.cols is not None:
return self._get_template_sparse(template_id)
return self._get_template_sparse(template_id, unwhiten=unwhiten)
else:
return self._get_template_dense(
template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold)
template_id, channel_ids=channel_ids, amplitude_threshold=amplitude_threshold, unwhiten=unwhiten)

def get_waveforms(self, spike_ids, channel_ids=None):
"""Return spike waveforms on specified channels."""
@@ -1047,7 +1072,7 @@ def get_depths(self):
# take only first component
features = self.sparse_features.data[ispi, :, 0]
features = np.maximum(features, 0) ** 2 # takes only positive values into account
ichannels = self.sparse_features.cols[self.spike_clusters[ispi]].astype(np.uint32)
ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.uint32)
# features = np.square(self.sparse_features.data[ispi, :, 0])
# ichannels = self.sparse_features.cols[self.spike_templates[ispi]].astype(np.int64)
ypos = self.channel_positions[ichannels, 1]
@@ -1059,7 +1084,7 @@ def get_depths(self):
break
return spikes_depths

def get_amplitudes_true(self, sample2unit=1.):
def get_amplitudes_true(self, sample2unit=1., use='templates'):
"""Convert spike amplitude values to input amplitudes units
via scaling by unwhitened template waveform.
:param sample2unit float: factor to convert the raw data to a physical unit (defaults 1.)
@@ -1074,26 +1099,35 @@
# spike_amp = ks2_spike_amps * maxmin(inv_whitening(ks2_template_amps))
# to rescale the template,

if use == 'clusters':
sparse = self.sparse_clusters
spikes = self.spike_clusters
n_wav = self.n_clusters
else:
sparse = self.sparse_templates
spikes = self.spike_templates
n_wav = self.n_templates

# unwhiten template waveforms on their channels of max amplitude
if self.sparse_templates.cols:
if sparse.cols:
raise NotImplementedError
# apply the inverse whitening matrix to the template
templates_wfs = np.zeros_like(self.sparse_templates.data) # nt, ns, nc
for n in np.arange(self.n_templates):
templates_wfs[n, :, :] = np.matmul(self.sparse_templates.data[n, :, :], self.wmi)
templates_wfs = np.zeros_like(sparse.data) # nt, ns, nc
for n in np.arange(n_wav):
templates_wfs[n, :, :] = np.matmul(sparse.data[n, :, :], self.wmi)

# The amplitude on each channel is the positive peak minus the negative
templates_ch_amps = np.max(templates_wfs, axis=1) - np.min(templates_wfs, axis=1)

# The template arbitrary unit amplitude is the amplitude of its largest channel
# (but see below for true tempAmps)
templates_amps_au = np.max(templates_ch_amps, axis=1)
spike_amps = templates_amps_au[self.spike_templates] * self.amplitudes
spike_amps = templates_amps_au[spikes] * self.amplitudes

with np.errstate(divide='ignore', invalid='ignore'):
# take the average spike amplitude per template
templates_amps_v = (np.bincount(self.spike_templates, weights=spike_amps) /
np.bincount(self.spike_templates))
templates_amps_v = (np.bincount(spikes, weights=spike_amps) /
np.bincount(spikes))
# scale back the template according to the spikes units
templates_physical_unit = templates_wfs * (templates_amps_v / templates_amps_au
)[:, np.newaxis, np.newaxis]
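A compact sketch of the rescaling chain above, under simplifying assumptions (dense waveforms, an identity inverse whitening matrix, made-up shapes and spike assignments):

import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=(2, 4, 3))             # whitened waveforms (n_units, ns, nc)
wmi = np.eye(3)                               # inverse whitening matrix (assumed identity)
spikes = np.array([0, 0, 1, 1, 1])            # unit id of each spike
amplitudes = np.array([1., 2., 1., 1., 4.])   # arbitrary-unit per-spike scalings

wfs = data @ wmi                              # unwhiten all units at once
ch_amps = wfs.max(axis=1) - wfs.min(axis=1)   # peak-to-peak amplitude per channel
amps_au = ch_amps.max(axis=1)                 # arbitrary-unit amplitude per unit
spike_amps = amps_au[spikes] * amplitudes     # per-spike amplitude
amps_v = np.bincount(spikes, weights=spike_amps) / np.bincount(spikes)
physical = wfs * (amps_v / amps_au)[:, None, None]   # waveforms in spike units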
@@ -1167,18 +1201,18 @@ def get_template_waveforms(self, template_id):
template = self.get_template(template_id)
return template.template if template else None

def get_cluster_mean_waveforms(self, cluster_id):
def get_cluster_mean_waveforms(self, cluster_id, unwhiten=True):
"""Return the mean template waveforms of a cluster, as a weighted average of the
template waveforms from which the cluster originates from."""
count = self.get_template_counts(cluster_id)
best_template = np.argmax(count)
template_ids = np.nonzero(count)[0]
count = count[template_ids]
# Get local channels of the best template for the given cluster.
template = self.get_template(best_template)
template = self.get_template(best_template, unwhiten=unwhiten)
channel_ids = template.channel_ids
# Get all templates from which this cluster stems from.
templates = [self.get_template(template_id) for template_id in template_ids]
templates = [self.get_template(template_id, unwhiten=unwhiten) for template_id in template_ids]
# Construct the waveforms array.
ns = self.n_samples_waveforms
data = np.zeros((len(template_ids), ns, self.n_channels))
@@ -1204,14 +1238,24 @@ def get_cluster_spike_waveforms(self, cluster_id):

@property
def templates_channels(self):
"""Returns a vector of peak channels for all templates"""
tmp = self.sparse_templates.data
"""Returns a vector of peak channels for all templates waveforms"""
return self._channels(self.sparse_templates)

@property
def clusters_channels(self):
"""Returns a vector of peak channels for all clusters waveforms"""
channels = self._channels(self.sparse_clusters)
return channels

def _channels(self, sparse):
""" Gets peak channels for each waveform"""
tmp = sparse.data
n_templates, n_samples, n_channels = tmp.shape
if self.sparse_templates.cols is None:
if sparse.cols is None:
template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
else:
# when the templates are sparse, the first channel is the highest amplitude channel
template_peak_channels = self.sparse_templates.cols[:, 0]
template_peak_channels = sparse.cols[:, 0]
assert template_peak_channels.shape == (n_templates,)
return template_peak_channels
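The dense branch picks the channel with the largest peak-to-peak swing; a minimal toy case:

import numpy as np

# One waveform, 3 channels; channel 2 has the deepest trough, hence the
# largest peak-to-peak value, so it is the peak channel.
tmp = np.zeros((1, 5, 3))
tmp[0, 2, 1] = -3.
tmp[0, 2, 2] = -8.
peak = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
print(peak)                             # [2]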

@@ -1223,16 +1267,33 @@ def templates_probes(self):
@property
def templates_amplitudes(self):
"""Returns the average amplitude per cluster"""
tid = np.unique(self.spike_templates)
n = np.bincount(self.spike_templates)[tid]
a = np.bincount(self.spike_templates, weights=self.amplitudes)[tid]
return self._amplitudes(self.spike_templates)

@property
def clusters_amplitudes(self):
"""Returns the average amplitude per cluster"""
return self._amplitudes(self.spike_clusters)

def _amplitudes(self, tmp):
""" Compute average amplitude for spikes"""
tid = np.unique(tmp)
n = np.bincount(tmp)[tid]
a = np.bincount(tmp, weights=self.amplitudes)[tid]
n[np.isnan(n)] = 1
return a / n
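A toy run of _amplitudes with made-up spike labels and amplitudes; note that units with no spikes are simply absent from the output:

import numpy as np

tmp = np.array([0, 0, 2, 2, 2])                  # unit id per spike (unit 1 empty)
amplitudes = np.array([1., 3., 2., 2., 5.])
tid = np.unique(tmp)                             # [0, 2]
n = np.bincount(tmp)[tid]                        # [2, 3] spikes per unit
a = np.bincount(tmp, weights=amplitudes)[tid]    # [4., 9.] summed amplitudes
print(a / n)                                     # [2. 3.] mean amplitude per unit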

@property
def templates_waveforms_durations(self):
"""Returns a vector of waveform durations (ms) for all templates"""
tmp = self.sparse_templates.data
return self._waveform_durations(self.sparse_templates.data)

@property
def clusters_waveforms_durations(self):
"""Returns a vector of waveform durations (ms) for all clusters"""
waveform_duration = self._waveform_durations(self.sparse_clusters.data)
return waveform_duration

def _waveform_durations(self, tmp):
n_templates, n_samples, n_channels = tmp.shape
# Compute the peak channels for each template.
template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
@@ -1241,6 +1302,23 @@ def templates_waveforms_durations(self):
(n_templates, n_channels), mode='raise', order='C')
return durations.flatten()[ind].astype(np.float64) / self.sample_rate * 1e3
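A standalone sketch of the peak-to-trough computation, with a fabricated waveform sampled at 30 kHz (trough at sample 10, peak at sample 40 on the peak channel, i.e. 1.0 ms):

import numpy as np

sample_rate = 30000.
tmp = np.zeros((1, 82, 2))                       # (n_units, n_samples, n_channels)
tmp[0, 10, 0] = -5.                              # trough on channel 0
tmp[0, 40, 0] = 3.                               # peak on channel 0
peak_ch = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
durations = tmp.argmax(axis=1) - tmp.argmin(axis=1)      # in samples, per channel
dur_ms = durations[np.arange(1), peak_ch] / sample_rate * 1e3
print(dur_ms)                                    # [1.]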

def cluster_waveforms(self):
"""
Computes the cluster waveforms for split and merged clusters
:return:
"""
# Only non sparse implementation
ns = self.n_samples_waveforms
data = np.zeros((np.max(self.cluster_ids) + 1, ns, self.n_channels))
for clust, val in self.merge_map.items():
if len(val) > 1:
mean_waveform = self.get_cluster_mean_waveforms(clust, unwhiten=False)
data[clust, :, mean_waveform.channel_ids] = np.swapaxes(mean_waveform.mean_waveforms, 0, 1)
elif len(val) == 1:
data[clust, :, :] = self.sparse_templates.data[val[0], :, :]

return Bunch(data=data, cols=None)

#--------------------------------------------------------------------------
# Saving methods
#--------------------------------------------------------------------------
15 changes: 14 additions & 1 deletion phylib/io/tests/conftest.py
@@ -98,6 +98,19 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True):
_remove(tempdir / 'whitening_mat_inv.npy')
_remove(tempdir / 'sim_binary.dat')

if param == 'merged':
# remove this file to make templates dense
_remove(tempdir / 'template_ind.npy')
clus = np.load(tempdir / 'spike_clusters.npy')
max_clus = np.max(clus)
# merge cluster 0 and 1
clus[np.bitwise_or(clus == 0, clus == 1)] = max_clus + 1
# split cluster 9 into two clusters
idx = np.where(clus == 9)[0]
clus[idx[0:3]] = max_clus + 2
clus[idx[3:]] = max_clus + 3
np.save(tempdir / 'spike_clusters.npy', clus)
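For clarity, a dry run of the 'merged' fixture logic on a hypothetical label vector:

import numpy as np

clus = np.array([0, 1, 9, 9, 9, 9, 9, 5])
max_clus = np.max(clus)                                     # 9
clus[np.bitwise_or(clus == 0, clus == 1)] = max_clus + 1    # merge 0 and 1 -> 10
idx = np.where(clus == 9)[0]
clus[idx[0:3]] = max_clus + 2                               # first part of split -> 11
clus[idx[3:]] = max_clus + 3                                # rest of split -> 12
print(clus)                                                 # [10 10 11 11 11 12 12  5]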

# Spike attributes.
if has_spike_attributes:
write_array(tempdir / 'spike_fail.npy', np.full(10, np.nan)) # wrong number of spikes
@@ -120,7 +133,7 @@ def _make_dataset(tempdir, param='dense', has_spike_attributes=True):
return template_path


@fixture(scope='function', params=('dense', 'sparse', 'misc'))
@fixture(scope='function', params=('dense', 'sparse', 'misc', 'merged'))
def template_path_full(tempdir, request):
return _make_dataset(tempdir, request.param)
