Skip to content

Commit

Permalink
DBM: allocate some comm buffers outside of loop
Browse files Browse the repository at this point in the history
* Allocate send/receive buffers for the known maximum size.
* Avoid frequent allocations inside the loop.
  • Loading branch information
hfp committed May 22, 2024
1 parent 4e8a921 commit 81d1c75
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/dbm/dbm_multiply_comm.c
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,18 @@ static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
dist_ticks, nticks, nsend_packs, plans_per_pack,
nblks_send_per_pack, ndata_send_per_pack);

// Can not parallelize over packs because there might be too few of them.
for (int ipack = 0; ipack < nsend_packs; ipack++) {
// Allocate send buffers.
dbm_pack_block_t *blks_send = dbm_mpi_alloc_mem(nblks_send_per_pack[ipack] *
sizeof(dbm_pack_block_t));
double *data_send =
dbm_mempool_host_malloc(ndata_send_per_pack[ipack] * sizeof(double));
// Allocate send buffers for maximum number of blocks/data over all packs.
int nblks_send_max = 0, ndata_send_max = 0;
for (int ipack = 0; ipack < nsend_packs; ++ipack) {
nblks_send_max = imax(nblks_send_max, nblks_send_per_pack[ipack]);
ndata_send_max = imax(ndata_send_max, ndata_send_per_pack[ipack]);
}
dbm_pack_block_t *blks_send =
dbm_mpi_alloc_mem(nblks_send_max * sizeof(dbm_pack_block_t));
double *data_send = dbm_mempool_host_malloc(ndata_send_max * sizeof(double));

// Cannot parallelize over packs (there might be too few of them).
for (int ipack = 0; ipack < nsend_packs; ipack++) {
// Fill send buffers according to plans.
const int nranks = dist->nranks;
int blks_send_count[nranks], data_send_count[nranks];
Expand Down Expand Up @@ -392,7 +396,6 @@ static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
dbm_mpi_alltoallv_byte(
blks_send, blks_send_count_byte, blks_send_displ_byte, blks_recv,
blks_recv_count_byte, blks_recv_displ_byte, dist->comm);
dbm_mpi_free_mem(blks_send);

// 3rd communication: Exchange data counts.
// TODO: could be computed from blks_recv.
Expand All @@ -406,7 +409,6 @@ static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
dbm_mpi_alltoallv_double(data_send, data_send_count, data_send_displ,
data_recv, data_recv_count, data_recv_displ,
dist->comm);
dbm_mempool_free(data_send);

// Post-process received blocks and assemble them into a pack.
postprocess_received_blocks(nranks, dist_indices->nshards, nblocks_recv,
Expand All @@ -418,6 +420,10 @@ static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
packed.send_packs[ipack].data = data_recv;
}

// Deallocate send buffers.
dbm_mpi_free_mem(blks_send);
dbm_mempool_free(data_send);

// Allocate pack_recv.
int max_nblocks = 0, max_data_size = 0;
for (int ipack = 0; ipack < packed.nsend_packs; ipack++) {
Expand Down

0 comments on commit 81d1c75

Please sign in to comment.