Skip to content

Commit

Permalink
Merge branch 'master' of github.com:baccuslab/pyret
Browse files Browse the repository at this point in the history
  • Loading branch information
Niru Maheswaranathan committed Nov 18, 2016
2 parents 390a461 + 0d2d4b2 commit 15df233
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ dist/
htmlcov/
.cache/
.DS_Store
tests/test-images/
40 changes: 17 additions & 23 deletions pyret/spiketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

__all__ = ['binspikes', 'estfr', 'detectevents', 'peakdet', 'SpikingEvent']

Expand Down Expand Up @@ -105,17 +106,10 @@ def __eq__(self, other):
return (self.start == other.start) & (self.stop == other.stop)

def trial_counts(self):
"""
Count the number of spikes per trial
>> counts = spkevent.trial_counts()
"""
counts, _ = np.histogram(self.spikes[:, 1], bins=np.arange(
np.min(self.spikes[:, 1]), np.max(self.spikes[:, 1])))
return counts
"""Count the number of spikes per trial"""
return Counter(self.spikes[:, 1])

def event_stats(self):
def stats(self):
"""
Compute statistics (mean and standard deviation) across spike counts
Expand All @@ -124,7 +118,7 @@ def event_stats(self):
"""

# count number of spikes per trial
counts = self.trial_counts()
counts = list(self.trial_counts().values())

return np.mean(counts), np.std(counts)

Expand Down Expand Up @@ -219,7 +213,7 @@ def detectevents(spk, threshold=(0.3, 0.05)):
threshold : (float, float), optional
A tuple of two floats that are used as thresholds for detecting firing
events. Default: (0.1, 0.005) see `peakdetect.py` for more info
events. Default: (0.1, 0.005) see `peakdet` for more info
Returns
-------
Expand All @@ -228,10 +222,10 @@ def detectevents(spk, threshold=(0.3, 0.05)):
See the `spikingevent` class for more info.
"""
# find peaks in the PSTH
bspk, tax = binspikes(spk[:, 0], binsize=0.01,
num_trials=np.max(spk[:, 1]))
psth = estfr(tax, bspk, sigma=0.02)
maxtab, _ = peakdet(psth, threshold[0], tax)
time = np.arange(0, np.ceil(spk[:, 0].max()), 0.01)
bspk = binspikes(spk[:, 0], time)
psth = estfr(bspk, time, sigma=0.01)
maxtab, _ = peakdet(psth, threshold[0], time)

# store spiking events in a list
events = list()
Expand All @@ -241,21 +235,21 @@ def detectevents(spk, threshold=(0.3, 0.05)):

# get putative start and stop indices of each spiking event
start_indices = np.where((psth <= threshold[1]) &
(tax < maxtab[eventidx, 0]))[0]
(time < maxtab[eventidx, 0]))[0]
stop_indices = np.where((psth <= threshold[1]) &
(tax > maxtab[eventidx, 0]))[0]
(time > maxtab[eventidx, 0]))[0]

# find the start time, defined as the right most peak index
if start_indices.size == 0:
starttime = tax[0]
starttime = time[0]
else:
starttime = tax[np.max(start_indices)]
starttime = time[np.max(start_indices)]

# find the stop time, defined as the lest most peak index
if stop_indices.size == 0:
stoptime = tax[-1]
stoptime = time[-1]
else:
stoptime = tax[np.min(stop_indices)]
stoptime = time[np.min(stop_indices)]

# find spikes within this time interval
event_spikes = spk[(spk[:, 0] >= starttime) &
Expand All @@ -268,7 +262,7 @@ def detectevents(spk, threshold=(0.3, 0.05)):
if not events or not (events[-1] == event):
events.append(event)

return tax, psth, bspk, events
return time, psth, bspk, events


def peakdet(v, delta, x=None):
Expand Down
72 changes: 64 additions & 8 deletions tests/test_spiketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,73 @@ def test_estfr():
assert (fr.sum() * dt) == bspk.sum()


def test_detectevents():
pass
def test_spiking_events():
np.random.seed(1234)

# generate spike times
spiketimes = np.array([0.1, 0.25, 0.5, 0.75, 0.9])
N = len(spiketimes)
T = 50
jitter = 0.01
spikes = []
for trial_index in range(T):
s = spiketimes + jitter * np.random.randn(N,)
spikes.append(np.stack((s, trial_index * np.ones(N,))))
spikes = np.hstack(spikes).T

def test_peakdet():
pass
# detect events
t, psth, bspk, events = spk.detectevents(spikes)

# correct number of events
assert len(events) == N

# test SpikingEvent class
ev = events[0]
assert isinstance(ev, spk.SpikingEvent)

# mean jitter should be close to the selected amount of jitter
mean_jitter = np.mean([e.jitter() for e in events])
assert np.allclose(mean_jitter, jitter, atol=1e-3)

# time to first spike (TTFS) should match the only spike in each trial
assert np.allclose(ev.spikes[:, 0], ev.ttfs())

def test_split_trials():
pass
# one spike per trial
mu, sigma = ev.stats()
assert mu == 1
assert sigma == 0

# test sorting
sorted_spks = ev.sort()
sorted_spks = sorted_spks[np.argsort(sorted_spks[:, 1]), 0]
assert np.all(np.diff(sorted_spks) > 0)


def test_peakdet():

def test_SpikingEvent():
pass
# create a toy signal
u = np.linspace(-5, 5, 1001)
x = np.exp(-u ** 2)
dx = np.gradient(x, 1e-2)

# one peak in x (delta=0.5)
maxtab, mintab = spk.peakdet(x, delta=0.5)
assert len(mintab) == 0
assert len(maxtab) == 1
assert np.allclose(maxtab, np.array([[500, 1]]))

# one peak in x (delta=0.1)
maxtab, mintab = spk.peakdet(x, delta=0.1)
assert len(mintab) == 0
assert len(maxtab) == 1
assert np.allclose(maxtab, np.array([[500, 1]]))

# no peaks in x (delta=1.0)
maxtab, mintab = spk.peakdet(x, delta=1.)
assert len(mintab) == 0
assert len(maxtab) == 0

# one peak and one valley in dx
maxtab, mintab = spk.peakdet(dx, delta=0.2)
assert np.allclose(maxtab, np.array([[429, 0.8576926]]))
assert np.allclose(mintab, np.array([[571, -0.8576926]]))
1 change: 0 additions & 1 deletion tests/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,3 @@ def test_playrates():
filename, 1)
os.remove(filename)
plt.close('all')

0 comments on commit 15df233

Please sign in to comment.