From b08f10f3d42c75da8fcbeb6353b0dd1beaec3a1e Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Fri, 25 Sep 2020 21:15:42 -0700 Subject: [PATCH] TST: Add tests asserting correct calculation of mean rates --- hnn_core/network.py | 7 +++++ hnn_core/tests/test_network.py | 55 ++++++++++++++++++---------------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index 54aa0988a..e46f7c227 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -484,6 +484,13 @@ def mean_rates(self, tstart, tstop, gid_dict, mean_type='all'): raise ValueError("Invalid mean_type. Valid arguments include " "'all', 'trial', or 'cell'.") + # Validate tstart, tstop + if not isinstance(tstart, (int, float)) or not isinstance( + tstop, (int, float)): + raise ValueError('tstart and tstop must be of type int or float') + elif tstop <= tstart: + raise ValueError('tstop must be greater than tstart') + for cell_type in cell_types: cell_type_gids = np.array(gid_dict[cell_type]) gid_spike_rate = np.zeros((len(self._times), len(cell_type_gids))) diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 512a769cf..b18b69ab5 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -62,26 +62,14 @@ def test_spikes(): spiketimes = [[2.3456, 7.89], [4.2812, 93.2]] spikegids = [[1, 3], [5, 7]] spiketypes = [['L2_pyramidal', 'L2_basket'], ['L5_pyramidal', 'L5_basket']] + tstart, tstop = 0.1, 98.4 + gid_dict = {'L2_pyramidal': range(1, 2), 'L2_basket': range(3, 4), + 'L5_pyramidal': range(5, 6), 'L5_basket': range(7, 8)} spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) spikes.plot_hist(show=False) spikes.write('/tmp/spk_%d.txt') assert spikes == read_spikes('/tmp/spk_*.txt') assert ("Spikes | 2 simulation trials" in repr(spikes)) - # assert spikes.mean_rates() == { - # 'L5_pyramidal': 5.08646998982706, - # 'L5_basket': 5.08646998982706, - # 'L2_pyramidal': 5.08646998982706, - # 'L2_basket': 5.08646998982706} - # assert spikes.mean_rates(mean_type='trial') == { - # 'L5_pyramidal': [0.0, 10.17293997965412], - # 'L5_basket': [0.0, 10.17293997965412], - # 'L2_pyramidal': [10.17293997965412, 0.0], - # 'L2_basket': [10.17293997965412, 0.0]} - # assert spikes.mean_rates(mean_type='cell') == { - # 'L5_pyramidal': [[0.0], [10.17293997965412]], - # 'L5_basket': [[0.0], [10.17293997965412]], - # 'L2_pyramidal': [[10.17293997965412], [0.0]], - # 'L2_basket': [[10.17293997965412], [0.0]]} with pytest.raises(TypeError, match="times should be a list of lists"): spikes = Spikes(times=([2.3456, 7.89], [4.2812, 93.2]), gids=spikegids, @@ -95,15 +83,6 @@ def test_spikes(): spikes = Spikes(times=[[2.3456, 7.89]], gids=spikegids, types=spiketypes) - with pytest.raises(ValueError, match="tstart and tstop must be of type " - "int or float"): - spikes = Spikes() - spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict={}) - - with pytest.raises(ValueError, match="tstop must be greater than tstart"): - spikes = Spikes() - spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict={}) - spikes = Spikes(times=spiketimes, gids=spikegids, types=spiketypes) with pytest.raises(TypeError, match="spike_types should be str, " @@ -122,9 +101,35 @@ def test_spikes(): with pytest.raises(ValueError, match="No input types found for ABC"): spikes.plot_hist(spike_types='ABC', show=False) + with pytest.raises(ValueError, match="tstart and tstop must be of type " + "int or float"): + spikes.mean_rates(tstart=0.1, tstop='ABC', gid_dict=gid_dict) + + with pytest.raises(ValueError, match="tstop must be greater than tstart"): + spikes.mean_rates(tstart=0.1, tstop=-1.0, gid_dict=gid_dict) + with pytest.raises(ValueError, match="Invalid mean_type. Valid " "arguments include 'all', 'trial', or 'cell'."): - spikes.mean_rates(tstart=0.1, tstop=98.4, gid_dict={}, mean_type='ABC') + spikes.mean_rates(tstart=tstart, tstop=tstop, gid_dict=gid_dict, + mean_type='ABC') + + test_rate = (1 / (tstop - tstart)) * 1000 + + assert spikes.mean_rates(tstart, tstop, gid_dict) == { + 'L5_pyramidal': test_rate / 2, + 'L5_basket': test_rate / 2, + 'L2_pyramidal': test_rate / 2, + 'L2_basket': test_rate / 2} + assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='trial') == { + 'L5_pyramidal': [0.0, test_rate], + 'L5_basket': [0.0, test_rate], + 'L2_pyramidal': [test_rate, 0.0], + 'L2_basket': [test_rate, 0.0]} + assert spikes.mean_rates(tstart, tstop, gid_dict, mean_type='cell') == { + 'L5_pyramidal': [[0.0], [test_rate]], + 'L5_basket': [[0.0], [test_rate]], + 'L2_pyramidal': [[test_rate], [0.0]], + 'L2_basket': [[test_rate], [0.0]]} # Write spike file with no 'types' column # Check for gid_dict errors