Skip to content

Commit

Permalink
TST: Add tests asserting correct calculation of mean rates
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley committed Sep 26, 2020
1 parent c19f289 commit b08f10f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
7 changes: 7 additions & 0 deletions hnn_core/network.py
Expand Up @@ -484,6 +484,13 @@ def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'):
raise ValueError("Invalid mean_type. Valid arguments include "
"'all', 'trial', or 'cell'.")

# Validate tstart, tstop
if not isinstance(tstart, (int, float)) or not isinstance(
tstop, (int, float)):
raise ValueError('tstart and tstop must be of type int or float')
elif tstop <= tstart:
raise ValueError('tstop must be greater than tstart')

for cell_type in cell_types:
cell_type_gids = np.array(gid_dict[cell_type])
gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids)))
Expand Down
55 changes: 30 additions & 25 deletions hnn_core/tests/test_network.py
Expand Up @@ -62,26 +62,14 @@ def test_spikes():
spiketimes = [[2.3456, 7.89], [4.2812, 93.2]]
spikegids = [[1, 3], [5, 7]]
spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']]
tstart, tstop = 0.1, 98.4
gid_dict = {'L2_pyramidal': range(1, 2), 'L2_basket': range(3, 4),
'L5_pyramidal': range(5, 6), 'L5_basket': range(7, 8)}
spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes)
spikes.plot_hist(show=False)
spikes.write('/tmp/spk_%d.txt')
assert spikes == read_spikes('/tmp/spk_*.txt')
assert ("Spikes | 2 simulation trials" in repr(spikes))
# assert spikes.mean_rates() == {
# 'L5_pyramidal': 5.08646998982706,
# 'L5_basket': 5.08646998982706,
# 'L2_pyramidal': 5.08646998982706,
# 'L2_basket': 5.08646998982706}
# assert spikes.mean_rates(mean_type='trial') == {
# 'L5_pyramidal': [0.0, 10.17293997965412],
# 'L5_basket': [0.0, 10.17293997965412],
# 'L2_pyramidal': [10.17293997965412, 0.0],
# 'L2_basket': [10.17293997965412, 0.0]}
# assert spikes.mean_rates(mean_type='cell') == {
# 'L5_pyramidal': [[0.0], [10.17293997965412]],
# 'L5_basket': [[0.0], [10.17293997965412]],
# 'L2_pyramidal': [[10.17293997965412], [0.0]],
# 'L2_basket': [[10.17293997965412], [0.0]]}

with pytest.raises(TypeError, match="times should be a list of lists"):
spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids,
Expand All @@ -95,15 +83,6 @@ def test_spikes():
spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids,
types=spiketypes)

with pytest.raises(ValueError, match="tstart and tstop must be of type "
"int or float"):
spikes = Spikes()
spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict={})

with pytest.raises(ValueError, match="tstop must be greater than tstart"):
spikes = Spikes()
spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict={})

spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes)

with pytest.raises(TypeError, match="spike_types should be str, "
Expand All @@ -122,9 +101,35 @@ def test_spikes():
with pytest.raises(ValueError, match="No input types found for ABC"):
spikes.plot_hist(spike_types='ABC', show=False)

with pytest.raises(ValueError, match="tstart and tstop must be of type "
"int or float"):
spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict=gid_dict)

with pytest.raises(ValueError, match="tstop must be greater than tstart"):
spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict=gid_dict)

with pytest.raises(ValueError, match="Invalid mean_type. Valid "
"arguments include 'all', 'trial', or 'cell'."):
spikes.mean_rates(tstart=0.1, tstop=98.4, gid_dict={}, mean_type='ABC')
spikes.mean_rates(tstart=tstart, tstop=tstop, gid_dict=gid_dict,
mean_type='ABC')

test_rate = (1 / (tstop - tstart)) * 1000

assert spikes.mean_rates(tstart, tstop, gid_dict) == {
'L5_pyramidal': test_rate / 2,
'L5_basket': test_rate / 2,
'L2_pyramidal': test_rate / 2,
'L2_basket': test_rate / 2}
assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='trial') == {
'L5_pyramidal': [0.0, test_rate],
'L5_basket': [0.0, test_rate],
'L2_pyramidal': [test_rate, 0.0],
'L2_basket': [test_rate, 0.0]}
assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='cell') == {
'L5_pyramidal': [[0.0], [test_rate]],
'L5_basket': [[0.0], [test_rate]],
'L2_pyramidal': [[test_rate], [0.0]],
'L2_basket': [[test_rate], [0.0]]}

# Write spike file with no 'types' column
# Check for gid_dict errors
Expand Down

0 comments on commit b08f10f

Please sign in to comment.