Skip to content

Commit

Permalink
added thread information to call to get_nodes in connect function suc…
Browse files Browse the repository at this point in the history
…h that sibling containers return the correct device
  • Loading branch information
jakobj committed Feb 2, 2016
1 parent cf6468d commit 73ad9f8
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 71 deletions.
30 changes: 15 additions & 15 deletions nestkernel/conn_builder.cpp
Expand Up @@ -321,7 +321,7 @@ nest::ConnBuilder::change_connected_synaptic_elements( index sgid,
// check whether the source is on this mpi machine
if ( kernel().node_manager.is_local_gid( sgid ) )
{
Node* const source = kernel().node_manager.get_node( sgid );
Node* const source = kernel().node_manager.get_node( sgid, tid );
const thread source_thread = source->get_thread();

// check whether the source is on our thread
Expand All @@ -339,7 +339,7 @@ nest::ConnBuilder::change_connected_synaptic_elements( index sgid,
}
else
{
Node* const target = kernel().node_manager.get_node( tgid );
Node* const target = kernel().node_manager.get_node( tgid, tid );
const thread target_thread = target->get_thread();
// check whether the target is on our thread
if ( tid != target_thread )
Expand Down Expand Up @@ -551,7 +551,7 @@ nest::OneToOneBuilder::connect_()
continue;
}

Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

// check whether the target is on our thread
Expand Down Expand Up @@ -610,7 +610,7 @@ nest::OneToOneBuilder::disconnect_()
continue;
}

Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

// check whether the target is on our thread
Expand Down Expand Up @@ -672,7 +672,7 @@ nest::OneToOneBuilder::sp_connect_()
skip_conn_parameter_( tid );
continue;
}
Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

single_connect_( *sgid, *target, target_thread, rng );
Expand Down Expand Up @@ -719,7 +719,7 @@ nest::OneToOneBuilder::sp_disconnect_()

if ( !change_connected_synaptic_elements( *sgid, *tgid, tid, -1 ) )
continue;
Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

single_disconnect_( *sgid, *target, target_thread );
Expand Down Expand Up @@ -760,7 +760,7 @@ nest::AllToAllBuilder::connect_()
continue;
}

Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

// check whether the target is on our thread
Expand Down Expand Up @@ -831,7 +831,7 @@ nest::AllToAllBuilder::sp_connect_()
skip_conn_parameter_( tid );
continue;
}
Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();
single_connect_( *sgid, *target, target_thread, rng );
}
Expand Down Expand Up @@ -873,7 +873,7 @@ nest::AllToAllBuilder::disconnect_()
continue;
}

Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

// check whether the target is on our thread
Expand Down Expand Up @@ -930,7 +930,7 @@ nest::AllToAllBuilder::sp_disconnect_()
skip_conn_parameter_( tid );
continue;
}
Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();
single_disconnect_( *sgid, *target, target_thread );
}
Expand Down Expand Up @@ -982,7 +982,7 @@ nest::FixedInDegreeBuilder::connect_()
if ( not kernel().node_manager.is_local_gid( *tgid ) )
continue;

Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

// check whether the target is on our thread
Expand Down Expand Up @@ -1084,7 +1084,7 @@ nest::FixedOutDegreeBuilder::connect_()
if ( not kernel().node_manager.is_local_gid( *tgid ) )
continue;

Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

// check whether the target is on our thread
Expand Down Expand Up @@ -1226,7 +1226,7 @@ nest::FixedTotalNumberBuilder::connect_()
// targets_on_vp vector
const long_t tgid = targets_on_vp[ vp_id ][ t_index ];

Node* const target = kernel().node_manager.get_node( tgid );
Node* const target = kernel().node_manager.get_node( tgid, tid );
const thread target_thread = target->get_thread();

if ( autapses_ or sgid != tgid )
Expand Down Expand Up @@ -1277,7 +1277,7 @@ nest::BernoulliBuilder::connect_()
if ( not kernel().node_manager.is_local_gid( *tgid ) )
continue;

Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

// check whether the target is on our thread
Expand Down Expand Up @@ -1401,7 +1401,7 @@ nest::SPBuilder::connect_( GIDCollection sources, GIDCollection targets )
skip_conn_parameter_( tid );
continue;
}
Node* const target = kernel().node_manager.get_node( *tgid );
Node* const target = kernel().node_manager.get_node( *tgid, tid );
const thread target_thread = target->get_thread();

single_connect_( *sgid, *target, target_thread, rng );
Expand Down
147 changes: 103 additions & 44 deletions nestkernel/connection_builder_manager.cpp
Expand Up @@ -339,31 +339,50 @@ nest::ConnectionBuilderManager::connect( index sgid,
double_t d,
double_t w )
{
Node* const source = kernel().node_manager.get_node( sgid, target_thread );
const thread tid = kernel().vp_manager.get_thread_id();
Node* source = kernel().node_manager.get_node( sgid, target_thread );

// normal nodes and devices with proxies
// target is a normal node or device with proxies
if ( target->has_proxies() )
{
connect_( *source, *target, sgid, target_thread, syn, d, w );
}
else if ( target->local_receiver() ) // normal devices
else if ( target->local_receiver() ) // target is a normal device
{
// make sure source is on this MPI rank
if ( source->is_proxy() )
{
return;
}

// make sure connections are only created on the thread of the device
if ( ( source->get_thread() != target_thread ) && ( source->has_proxies() ) )
{
target_thread = source->get_thread();
target = kernel().node_manager.get_node( target->get_gid(), target_thread );
return;
}

connect_( *source, *target, sgid, target_thread, syn, d, w );
if ( source->has_proxies() ) // normal neuron->device connection
{
connect_( *source, *target, sgid, target_thread, syn, d, w );
}
else // create device->device connections on suggested thread of target
{
target_thread = kernel().vp_manager.suggest_vp( target->get_gid() );
if ( target_thread == tid )
{
source = kernel().node_manager.get_node( sgid, target_thread );
target = kernel().node_manager.get_node( target->get_gid(), target_thread );
connect_( *source, *target, sgid, target_thread, syn, d, w );
}
}
}
else // globally receiving devices iterate over all target threads
else // globally receiving devices iterate over all target threads, e.g., volume transmitter
{
if ( !source->has_proxies() ) // we do not allow to connect a device to a global receiver at the
// moment
// we do not allow to connect a device to a global receiver at the moment
if ( not source->has_proxies() )
{
return;
}
const thread n_threads = kernel().vp_manager.get_num_threads();
for ( thread t = 0; t < n_threads; t++ )
{
Expand All @@ -383,31 +402,50 @@ nest::ConnectionBuilderManager::connect( index sgid,
double_t d,
double_t w )
{
Node* const source = kernel().node_manager.get_node( sgid, target_thread );
const thread tid = kernel().vp_manager.get_thread_id();
Node* source = kernel().node_manager.get_node( sgid, target_thread );

// normal nodes and devices with proxies
// target is a normal node or device with proxies
if ( target->has_proxies() )
{
connect_( *source, *target, sgid, target_thread, syn, params, d, w );
}
else if ( target->local_receiver() ) // normal devices
else if ( target->local_receiver() ) // target is a normal device
{
// make sure source is on this MPI rank
if ( source->is_proxy() )
{
return;
}

// make sure connections are only created on the thread of the device
if ( ( source->get_thread() != target_thread ) && ( source->has_proxies() ) )
{
target_thread = source->get_thread();
target = kernel().node_manager.get_node( target->get_gid(), target_thread );
return;
}

connect_( *source, *target, sgid, target_thread, syn, params, d, w );
if ( source->has_proxies() ) // normal neuron->device connection
{
connect_( *source, *target, sgid, target_thread, syn, d, w );
}
else // create device->device connections on suggested thread of target
{
target_thread = kernel().vp_manager.suggest_vp( target->get_gid() );
if ( target_thread == tid )
{
source = kernel().node_manager.get_node( sgid, target_thread );
target = kernel().node_manager.get_node( target->get_gid(), target_thread );
connect_( *source, *target, sgid, target_thread, syn, d, w );
}
}
}
else // globally receiving devices iterate over all target threads
else // globally receiving devices iterate over all target threads, e.g., volume transmitter
{
if ( !source->has_proxies() ) // we do not allow to connect a device to a global receiver at the
// moment
// we do not allow to connect a device to a global receiver at the moment
if ( not source->has_proxies() )
{
return;
}
const thread n_threads = kernel().vp_manager.get_num_threads();
for ( thread t = 0; t < n_threads; t++ )
{
Expand All @@ -419,53 +457,72 @@ nest::ConnectionBuilderManager::connect( index sgid,

// gid gid dict
bool
nest::ConnectionBuilderManager::connect( index source_id,
index target_id,
nest::ConnectionBuilderManager::connect( index sgid,
index tgid,
DictionaryDatum& params,
index syn )
{
const thread tid = kernel().vp_manager.get_thread_id();

if ( !kernel().node_manager.is_local_gid( target_id ) )
// make sure target is on this MPI rank
if ( !kernel().node_manager.is_local_gid( tgid ) )
{
return false;
}

Node* target_ptr = kernel().node_manager.get_node( target_id );

// target_thread defaults to 0 for devices
thread target_thread = target_ptr->get_thread();

Node* source_ptr = kernel().node_manager.get_node( source_id, target_thread );
Node* target = kernel().node_manager.get_node( tgid, tid );
thread target_thread = target->get_thread();
Node* source = kernel().node_manager.get_node( sgid, target_thread );

// normal nodes and devices with proxies
if ( target_ptr->has_proxies() )
// target is a normal node or device with proxies
if ( target->has_proxies() )
{
connect_( *source_ptr, *target_ptr, source_id, target_thread, syn, params );
connect_( *source, *target, sgid, target_thread, syn, params );
}
else if ( target_ptr->local_receiver() ) // normal devices
else if ( target->local_receiver() ) // target is a normal device
{
if ( source_ptr->is_proxy() )
// make sure source is on this MPI rank
if ( source->is_proxy() )
{
return false;
}

if ( ( source_ptr->get_thread() != target_thread ) && ( source_ptr->has_proxies() ) )
// make sure connections are only created on the thread of the device
if ( ( source->get_thread() != target_thread ) && ( source->has_proxies() ) )
{
target_thread = source_ptr->get_thread();
target_ptr = kernel().node_manager.get_node( target_id, target_thread );
return false;
}

if ( source->has_proxies() ) // normal neuron->device connection
{
connect_( *source, *target, sgid, target_thread, syn, params );
}
else // create device->device connections on suggested thread of target
{
target_thread = kernel().vp_manager.suggest_vp( target->get_gid() );
if ( target_thread == tid )
{
source = kernel().node_manager.get_node( sgid, target_thread );
target = kernel().node_manager.get_node( target->get_gid(), target_thread );
connect_( *source, *target, sgid, target_thread, syn, params );
}
}

connect_( *source_ptr, *target_ptr, source_id, target_thread, syn, params );
}
else // globally receiving devices iterate over all target threads
else // globally receiving devices iterate over all target threads, e.g., volume transmitter
{
if ( !source_ptr->has_proxies() ) // we do not allow to connect a device to a global receiver at
// the moment
// we do not allow to connect a device to a global receiver at the moment
if ( not source->has_proxies() )
{
return false;
}
const thread n_threads = kernel().vp_manager.get_num_threads();
for ( thread t = 0; t < n_threads; t++ )
{
target_ptr = kernel().node_manager.get_node( target_id, t );
connect_( *source_ptr, *target_ptr, source_id, t, syn, params );
target = kernel().node_manager.get_node( tgid, t );
connect_( *source, *target, sgid, t, syn, params );
}
}

// We did not exit prematurely due to proxies, so we have connected.
return true;
}
Expand Down Expand Up @@ -587,6 +644,8 @@ nest::ConnectionBuilderManager::divergent_connect( index source_id,
const TokenArray& delays,
index syn )
{
const thread tid = kernel().vp_manager.get_thread_id();

bool complete_wd_lists = ( target_ids.size() == weights.size() && weights.size() != 0
&& weights.size() == delays.size() );
bool short_wd_lists =
Expand Down Expand Up @@ -632,7 +691,7 @@ nest::ConnectionBuilderManager::divergent_connect( index source_id,
{
index gid = getValue< long >( target_ids[ i ] );
if ( kernel().node_manager.is_local_gid( gid ) )
targets.push_back( kernel().node_manager.get_node( gid ) );
targets.push_back( kernel().node_manager.get_node( gid, tid ) );
}

for ( index i = 0; i < targets.size(); ++i )
Expand Down Expand Up @@ -1810,4 +1869,4 @@ nest::ConnectionBuilderManager::get_targets( std::vector< index > sources,
}
}
}
}
}

0 comments on commit 73ad9f8

Please sign in to comment.