diff --git a/nestkernel/connection_manager.cpp b/nestkernel/connection_manager.cpp index 639f15c481..99630ee005 100644 --- a/nestkernel/connection_manager.cpp +++ b/nestkernel/connection_manager.cpp @@ -578,39 +578,51 @@ nest::ConnectionManager::connect_arrays( long* sources, sw_construction_connect.start(); // Mapping pointers to the first parameter value of each parameter to their respective names. - std::map< Name, double* > param_pointers; + // The bool indicates whether the value is an integer or not, and is determined at a later point. + std::map< Name, std::pair< double*, bool > > param_pointers; if ( p_keys.size() != 0 ) { size_t i = 0; for ( auto& key : p_keys ) { // Shifting the pointer to the first value of the parameter. - param_pointers[ key ] = p_values + i * n; + param_pointers[ key ] = std::make_pair( p_values + i * n, false ); ++i; } } + const auto synapse_model_id = kernel().model_manager.get_synapse_model_id( syn_model ); + const auto syn_model_defaults = kernel().model_manager.get_connector_defaults( synapse_model_id ); + // Dictionary holding additional synapse parameters, passed to the connect call. std::vector< DictionaryDatum > param_dicts; param_dicts.reserve( kernel().vp_manager.get_num_threads() ); for ( thread i = 0; i < kernel().vp_manager.get_num_threads(); ++i ) { param_dicts.emplace_back( new Dictionary ); - for ( auto& param_keys : p_keys ) + for ( auto& param_key : p_keys ) { - if ( Name( param_keys ) == names::receptor_type ) + const Name param_name = param_key; // Convert string to Name + // Check that the parameter exists for the synapse model. + const auto syn_model_default_it = syn_model_defaults->find( param_name ); + if ( syn_model_default_it == syn_model_defaults->end() ) { - ( *param_dicts[ i ] )[ param_keys ] = Token( new IntegerDatum( 0 ) ); + throw BadParameter( syn_model + " does not have parameter " + param_key ); + } + + // If the default value is an integer, the synapse parameter must also be an integer. + if ( dynamic_cast< IntegerDatum* >( syn_model_default_it->second.datum() ) != nullptr ) + { + param_pointers[ param_key ].second = true; + ( *param_dicts[ i ] )[ param_key ] = Token( new IntegerDatum( 0 ) ); } else { - ( *param_dicts[ i ] )[ param_keys ] = Token( new DoubleDatum( 0.0 ) ); + ( *param_dicts[ i ] )[ param_key ] = Token( new DoubleDatum( 0.0 ) ); } } } - const index synapse_model_id = kernel().model_manager.get_synapse_model_id( syn_model ); - // Increments pointers to weight and delay, if they are specified. auto increment_wd = [weights, delays]( decltype( weights ) & w, decltype( delays ) & d ) { if ( weights != nullptr ) @@ -640,7 +652,7 @@ nest::ConnectionManager::connect_arrays( long* sources, auto d = delays; double weight_buffer = numerics::nan; double delay_buffer = numerics::nan; - int index_counter = 0; + int index_counter = 0; // Index of the current connection, for connection parameters for ( ; s != sources + n; ++s, ++t, ++index_counter ) { @@ -674,16 +686,20 @@ nest::ConnectionManager::connect_arrays( long* sources, for ( auto& param_pointer_pair : param_pointers ) { // Increment the pointer to the parameter value. - auto* param = param_pointer_pair.second + index_counter; + const auto param_pointer = param_pointer_pair.second.first; + const auto is_int = param_pointer_pair.second.second; + auto* param = param_pointer + index_counter; - // Receptor type must be an integer. - if ( param_pointer_pair.first == names::receptor_type ) + // Integer parameters are stored as IntegerDatums. + if ( is_int ) { const auto rtype_as_long = static_cast< long >( *param ); if ( *param > 1L << 31 or std::abs( *param - rtype_as_long ) > 0 ) // To avoid rounding errors { - throw BadParameter( "Receptor types must be integers." ); + const auto msg = std::string( "Expected integer value for " ) + param_pointer_pair.first.toString() + + ", but got double."; + throw BadParameter( msg ); } // Change value of dictionary entry without allocating new datum. diff --git a/testsuite/pytests/test_connect_arrays.py b/testsuite/pytests/test_connect_arrays.py index e31598712c..3c2465729c 100644 --- a/testsuite/pytests/test_connect_arrays.py +++ b/testsuite/pytests/test_connect_arrays.py @@ -250,6 +250,30 @@ def test_connect_arrays_additional_synspec_params(self): self.assertEqual(c.alpha, a) self.assertEqual(c.tau, tau) + def test_connect_arrays_syn_lbl(self): + """Connecting NumPy arrays with synapse label""" + n = 10 + nest.Create('iaf_psc_alpha', n) + sources = np.arange(1, n+1, dtype=np.uint64) + targets = self.non_unique + weights = np.ones(len(sources)) + delays = np.ones(len(sources)) + syn_model = 'static_synapse_lbl' + syn_label = 2 + + nest.Connect(sources, targets, conn_spec='one_to_one', + syn_spec={'weight': weights, 'delay': delays, 'synapse_model': syn_model, + 'synapse_label': syn_label}) + + conns = nest.GetConnections() + + for s, t, w, d, lbl, c in zip(sources, targets, weights, delays, len(conns)*[syn_label], conns): + self.assertEqual(c.source, s) + self.assertEqual(c.target, t) + self.assertEqual(c.weight, w) + self.assertEqual(c.delay, d) + self.assertEqual(c.synapse_label, lbl) + def test_connect_arrays_float_rtype(self): """Raises exception when not using integer value for receptor_type""" n = 10