Skip to content

Commit

Permalink
Refactor implementation of operator+
Browse files Browse the repository at this point in the history
  • Loading branch information
heplesser committed Apr 18, 2024
1 parent 46195d6 commit 82d25d2
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 107 deletions.
152 changes: 72 additions & 80 deletions nestkernel/node_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,30 +124,39 @@ nc_const_iterator::nc_const_iterator( NodeCollectionPTR collection_ptr,
rank_or_vp_ ) ); )
}

void
nc_const_iterator::advance_composite_iterator_( size_t n )
size_t
nc_const_iterator::find_next_within_part_( size_t n ) const
{
// See if we can do a simple step, i.e., stay in current part
const size_t new_element_idx = element_idx_ + n * step_;
size_t primitive_size = composite_collection_->parts_[ part_idx_ ].size();
if ( new_element_idx < primitive_size )

if ( primitive_collection_ )
{
// Still in old part, check if we have passed end
if ( part_idx_ == composite_collection_->last_part_ and new_element_idx > composite_collection_->last_elem_ )
// Avoid running over end of collection
return std::min( new_element_idx, primitive_collection_->size() );
}

if ( new_element_idx < composite_collection_->parts_[ part_idx_ ].size() )
{
if ( composite_collection_->valid_idx_( part_idx_, new_element_idx ) )
{
// set iterator to uniquely defined end() element
assert( part_idx_ == composite_collection_->last_part_ );
element_idx_ = composite_collection_->last_elem_ + 1;
return;
// We have found an element in the part
return new_element_idx;
}
else
{
element_idx_ = new_element_idx;
return;
// We have reached the end of the node collection, return index for end iterator
assert( part_idx_ == composite_collection_->last_part_ );
return composite_collection_->last_elem_ + 1;
}
}

// We know that we need to look in another part
// No new element found in this part and collection not exhausted
return element_idx_;
}

void
nc_const_iterator::advance_global_iter_to_new_part_( size_t n )
{
if ( part_idx_ == composite_collection_->last_part_ )
{
// No more parts, set to end()
Expand All @@ -156,38 +165,58 @@ nc_const_iterator::advance_composite_iterator_( size_t n )
return;
}

// At least one more part available
// Find new position counting from beginning of node collection
const auto part_abs_begin = part_idx_ == 0 ? 0 : composite_collection_->cumul_abs_size_[ part_idx_ - 1 ];
const auto new_abs_idx = part_abs_begin + element_idx_ + n * composite_collection_->stride_;

// Confirm that new position is in a new part
assert( new_abs_idx >= composite_collection_->cumul_abs_size_[ part_idx_ ] );

if ( kind_ == NCIteratorKind::GLOBAL )
// Move to part that contains new position
do
{
// Simple stepping scheme without phase adjustment
std::tie( part_idx_, element_idx_ ) = composite_collection_->find_next_part_( part_idx_, element_idx_, n );
++part_idx_;
} while ( part_idx_ <= composite_collection_->last_part_
and composite_collection_->cumul_abs_size_[ part_idx_ ] <= new_abs_idx );

// In case we did not find a solution, set to end()
if ( part_idx_ == invalid_index )
{
part_idx_ = composite_collection_->last_part_;
element_idx_ = composite_collection_->last_elem_ + 1;
}
// If there is another element, it must have this index
element_idx_ = new_abs_idx - composite_collection_->cumul_abs_size_[ part_idx_ - 1 ];

if ( not composite_collection_->valid_idx_( part_idx_, element_idx_ ) )
{
// Node collection exhausted
part_idx_ = composite_collection_->last_part_;
element_idx_ = composite_collection_->last_elem_ + 1;
}
else
}

void
nc_const_iterator::advance_local_iter_to_new_part_( size_t n )
{
// We know that we need to look in another part
if ( part_idx_ == composite_collection_->last_part_ )
{
// We are stepping over rank- or thread-specific elements and need phase adjustment.
//
// Current part is exhausted. We need to find the next part containing
// and element compatible with stride_ and then perform phase adjustment.
assert( n == 1 );
// No more parts, set to end()
part_idx_ = composite_collection_->last_part_;
element_idx_ = composite_collection_->last_elem_ + 1;
return;
}

// {RANK,THREAD}_LOCAL iterators require phase adjustment
// which is feasible only for single steps, so unroll
for ( size_t k = 0; k < n; ++k )
{
// Find next part that has element in underlying GLOBAL stride
do
{
++part_idx_;
} while ( part_idx_ < composite_collection_->cumul_abs_size_.size()
} while ( part_idx_ <= composite_collection_->last_part_
and composite_collection_->first_in_part_[ part_idx_ ] == invalid_index );

if ( part_idx_ < composite_collection_->cumul_abs_size_.size() )
if ( part_idx_ <= composite_collection_->last_part_ )
{
// We have a candidate part and a first valid element in it
// We have a candidate part and a first valid element in it, so we perform phase adjustment

assert( composite_collection_->first_in_part_[ part_idx_ ] != invalid_index );
element_idx_ = composite_collection_->first_in_part_[ part_idx_ ];

Expand Down Expand Up @@ -220,21 +249,20 @@ nc_const_iterator::advance_composite_iterator_( size_t n )
assert( false ); // should not be here, otherwise kind_ is inconsistent
break;
}

// In case we did not find a solution in phase adjustment, set to end()
if ( part_idx_ == invalid_index )
{
part_idx_ = composite_collection_->last_part_;
element_idx_ = composite_collection_->last_elem_ + 1;
}
}
else
{
// Node collection exhausted, set to end()
part_idx_ = composite_collection_->last_part_;
element_idx_ = composite_collection_->last_elem_ + 1;
break; // no more parts to search
}
} // else kind_ == GLOBAL
}

// In case we did not find a solution in phase adjustment, set to end()
if ( part_idx_ == invalid_index or not composite_collection_->valid_idx_( part_idx_, element_idx_ ) )
{
// Node collection exhausted, set to end()
part_idx_ = composite_collection_->last_part_;
element_idx_ = composite_collection_->last_elem_ + 1;
}
}

void
Expand Down Expand Up @@ -1152,36 +1180,6 @@ NodeCollectionComposite::specific_local_begin_( size_t period,
return { invalid_index, invalid_index };
}

std::pair< size_t, size_t >
NodeCollectionComposite::find_next_part_( size_t part_idx, size_t element_idx, size_t n ) const
{
assert( part_idx < last_part_ );

// Find new position counting from beginning of node collection
const auto part_abs_begin = part_idx == 0 ? 0 : cumul_abs_size_[ part_idx - 1 ];
const auto new_abs_idx = part_abs_begin + element_idx + n * stride_;

// Confirm that new position is in a new part
assert( new_abs_idx >= cumul_abs_size_[ part_idx ] );

// Move to part that contains new position
do
{
++part_idx;
} while ( part_idx < cumul_abs_size_.size() and cumul_abs_size_[ part_idx ] <= new_abs_idx );

if ( part_idx >= cumul_abs_size_.size() or first_in_part_[ part_idx ] == invalid_index )
{
// Either we checked all parts or the part containing the index contains no element
// compatible with our stride
return { invalid_index, invalid_index };
}

// We have found a new element
return { part_idx, new_abs_idx - cumul_abs_size_[ part_idx - 1 ] };
}


size_t
NodeCollectionComposite::gid_to_vp_( size_t gid )
{
Expand Down Expand Up @@ -1308,12 +1306,6 @@ NodeCollectionComposite::merge_parts_( std::vector< NodeCollectionPrimitive >& p
}
}

bool
NodeCollectionComposite::contains( const size_t node_id ) const
{
return get_nc_index( node_id ) != -1;
}

long
NodeCollectionComposite::get_nc_index( const size_t node_id ) const
{
Expand Down
74 changes: 47 additions & 27 deletions nestkernel/node_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,19 @@ class nc_const_iterator
NCIteratorKind kind = NCIteratorKind::GLOBAL );

/**
* Advance composite iterator by n elements, taking stride into account.
* Return element_idx_ for next element if within part. Returns current element_idx_ otherwise.
*/
void advance_composite_iterator_( size_t n );
size_t find_next_within_part_( size_t n ) const;

/**
* Advance composite GLOBAL iterator by n elements, taking stride into account.
*/
void advance_global_iter_to_new_part_( size_t n );

/**
* Advance composite {THREAD,RANK}_LOCAL iterator by n elements, taking stride into account.
*/
void advance_local_iter_to_new_part_( size_t n );

public:
using iterator_category = std::forward_iterator_tag;
Expand Down Expand Up @@ -764,6 +774,11 @@ class NodeCollectionComposite : public NodeCollection
size_t start_offset,
gid_to_phase_fcn_ period_first_node ) const;

/**
* Return true if part_idx/element_idx pair indicates element of collection
*/
bool valid_idx_( const size_t part_idx, const size_t element_idx ) const;

/**
* Find next part and offset in it after moving beyond previous part, based on stride.
*
Expand All @@ -781,6 +796,7 @@ class NodeCollectionComposite : public NodeCollection
//! helper for rank_local_begin/compsite_update_indices
static size_t gid_to_rank_( size_t gid );


public:
/**
* Create a composite from a primitive, with boundaries and step length.
Expand Down Expand Up @@ -905,25 +921,30 @@ nc_const_iterator::operator+=( const size_t n )
{
assert( kind_ != NCIteratorKind::END );

if ( primitive_collection_ )
if ( n == 0 )
{
return *this;
}

const auto new_element_idx = find_next_within_part_( n );

// For a primitive collection, we either have a new element or are at the end
// For a composite collection, we may need to search through further parts,
// which is signalled by new_element_idx == element_idx_
if ( primitive_collection_ or new_element_idx != element_idx_ )
{
// Guard against passing end ( size() gives element_index_ for end() iterator )
element_idx_ = std::min( element_idx_ + n * step_, primitive_collection_->size() );
element_idx_ = new_element_idx;
}
else
{
// We did not find a new element in the current part and have not exhausted the collection
if ( kind_ == NCIteratorKind::GLOBAL )
{
advance_composite_iterator_( n );
advance_global_iter_to_new_part_( n );
}
else
{
// {RANK,THREAD}_LOCAL iterators require phase adjustment
// which is feasible only for single steps, so unroll
for ( size_t k = 0; k < n; ++k )
{
advance_composite_iterator_( 1 );
}
advance_local_iter_to_new_part_( n );
}
}

Expand All @@ -940,21 +961,7 @@ nc_const_iterator::operator+( const size_t n ) const
inline nc_const_iterator&
nc_const_iterator::operator++()
{
assert( kind_ != NCIteratorKind::END );

// This code is partial duplication of operator+(n), but because it is much simpler
// for composite collections than operator+(n), we code it explicitly here instead
// of redirecting to operator+(1).
if ( primitive_collection_ )
{
// Guard against passing end ( size() gives element_index_ for end() iterator )
element_idx_ = std::min( element_idx_ + step_, primitive_collection_->size() );
}
else
{
advance_composite_iterator_( 1 );
}

( *this ) += 1;
return *this;
}

Expand Down Expand Up @@ -1186,6 +1193,19 @@ NodeCollectionComposite::empty() const
// Composite NodeCollections can never be empty.
return false;
}

inline bool
NodeCollectionComposite::contains( const size_t node_id ) const
{
return get_nc_index( node_id ) != -1;
}

inline bool
NodeCollectionComposite::valid_idx_( const size_t part_idx, const size_t element_idx ) const
{
return part_idx < last_part_ or ( part_idx == last_part_ and element_idx <= last_elem_ );
}

} // namespace nest

#endif /* #ifndef NODE_COLLECTION_H */

0 comments on commit 82d25d2

Please sign in to comment.