dbm: leverage MPI_Alloc_mem/MPI_Free_mem (#3413)
hfp committed May 13, 2024
1 parent be1f588 commit ee725ff
Showing 8 changed files with 76 additions and 42 deletions.
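
The change follows one pattern throughout: buffers that are handed to MPI communication calls are now obtained from a thin wrapper over MPI_Alloc_mem and released through MPI_Free_mem, with a plain malloc/free fallback for serial builds. On MPI implementations that back MPI_Alloc_mem with registered (pinned) memory, this can speed up the subsequent communication. Below is a minimal, self-contained sketch of that pattern against the raw MPI API; the buffer size and the MPI_Allreduce call are illustrative only and are not taken from this commit.

#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  const int count = 1024; /* illustrative buffer size */
  double *sendbuf = NULL, *recvbuf = NULL;

  /* MPI_Alloc_mem may hand back memory registered with the interconnect,
     which some MPI stacks can exploit for faster transfers. */
  MPI_Alloc_mem((MPI_Aint)(count * sizeof(double)), MPI_INFO_NULL, &sendbuf);
  MPI_Alloc_mem((MPI_Aint)(count * sizeof(double)), MPI_INFO_NULL, &recvbuf);

  for (int i = 0; i < count; i++) {
    sendbuf[i] = 1.0;
  }
  MPI_Allreduce(sendbuf, recvbuf, count, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
  printf("recvbuf[0] = %g\n", recvbuf[0]);

  /* Memory from MPI_Alloc_mem must be released with MPI_Free_mem, not free(). */
  MPI_Free_mem(sendbuf);
  MPI_Free_mem(recvbuf);

  MPI_Finalize();
  return 0;
}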
11 changes: 5 additions & 6 deletions src/dbm/dbm_matrix.c
@@ -15,7 +15,6 @@

#include "dbm_hyperparams.h"
#include "dbm_matrix.h"
#include "dbm_mpi.h"

/*******************************************************************************
* \brief Creates a new matrix.
@@ -143,14 +142,14 @@ void dbm_redistribute(const dbm_matrix_t *matrix, dbm_matrix_t *redist) {
// Compute displacements and allocate data buffers.
int send_displ[nranks + 1], recv_displ[nranks + 1];
send_displ[0] = recv_displ[0] = 0;
for (int irank = 1; irank < nranks + 1; irank++) {
for (int irank = 1; irank <= nranks; irank++) {
send_displ[irank] = send_displ[irank - 1] + send_count[irank - 1];
recv_displ[irank] = recv_displ[irank - 1] + recv_count[irank - 1];
}
const int total_send_count = send_displ[nranks];
const int total_recv_count = recv_displ[nranks];
double *data_send = malloc(total_send_count * sizeof(double));
double *data_recv = malloc(total_recv_count * sizeof(double));
double *data_send = dbm_mpi_alloc_mem(total_send_count * sizeof(double));
double *data_recv = dbm_mpi_alloc_mem(total_recv_count * sizeof(double));

// 2nd pass: Fill send_data.
int send_data_positions[nranks];
@@ -178,7 +177,7 @@ void dbm_redistribute(const dbm_matrix_t *matrix, dbm_matrix_t *redist) {
// 2nd communication: Exchange data.
dbm_mpi_alltoallv_double(data_send, send_count, send_displ, data_recv,
recv_count, recv_displ, comm);
free(data_send);
dbm_mpi_free_mem(data_send);

// 3rd pass: Unpack data.
dbm_clear(redist);
@@ -195,7 +194,7 @@ void dbm_redistribute(const dbm_matrix_t *matrix, dbm_matrix_t *redist) {
recv_data_pos += 2 + block_size;
}
assert(recv_data_pos == total_recv_count);
free(data_recv);
dbm_mpi_free_mem(data_recv);
}

/*******************************************************************************
5 changes: 3 additions & 2 deletions src/dbm/dbm_mempool.c
@@ -14,6 +14,7 @@
#include "../offload/offload_library.h"
#include "../offload/offload_runtime.h"
#include "dbm_mempool.h"
#include "dbm_mpi.h"

/*******************************************************************************
* \brief Private routine for actually allocating system memory.
@@ -34,7 +35,7 @@ static void *actual_malloc(const size_t size, const bool on_device) {
(void)on_device; // mark used
#endif

void *memory = malloc(size);
void *memory = dbm_mpi_alloc_mem(size);
assert(memory != NULL);
return memory;
}
@@ -58,7 +59,7 @@ static void actual_free(void *memory, const bool on_device) {
(void)on_device; // mark used
#endif

free(memory);
dbm_mpi_free_mem(memory);
}

/*******************************************************************************
46 changes: 36 additions & 10 deletions src/dbm/dbm_mpi.c
@@ -237,10 +237,10 @@ bool dbm_mpi_comms_are_similar(const dbm_mpi_comm_t comm1,
******************************************************************************/
void dbm_mpi_max_int(int *values, const int count, const dbm_mpi_comm_t comm) {
#if defined(__parallel)
int *recvbuf = malloc(count * sizeof(int));
void *recvbuf = dbm_mpi_alloc_mem(count * sizeof(int));
CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_MAX, comm));
memcpy(values, recvbuf, count * sizeof(int));
free(recvbuf);
dbm_mpi_free_mem(recvbuf);
#else
(void)comm; // mark used
(void)values;
@@ -255,10 +255,10 @@ void dbm_mpi_max_int(int *values, const int count, const dbm_mpi_comm_t comm) {
void dbm_mpi_max_double(double *values, const int count,
const dbm_mpi_comm_t comm) {
#if defined(__parallel)
double *recvbuf = malloc(count * sizeof(double));
void *recvbuf = dbm_mpi_alloc_mem(count * sizeof(double));
CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_MAX, comm));
memcpy(values, recvbuf, count * sizeof(double));
free(recvbuf);
dbm_mpi_free_mem(recvbuf);
#else
(void)comm; // mark used
(void)values;
@@ -272,10 +272,10 @@ void dbm_mpi_max_double(double *values,
******************************************************************************/
void dbm_mpi_sum_int(int *values, const int count, const dbm_mpi_comm_t comm) {
#if defined(__parallel)
int *recvbuf = malloc(count * sizeof(int));
void *recvbuf = dbm_mpi_alloc_mem(count * sizeof(int));
CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT, MPI_SUM, comm));
memcpy(values, recvbuf, count * sizeof(int));
free(recvbuf);
dbm_mpi_free_mem(recvbuf);
#else
(void)comm; // mark used
(void)values;
@@ -290,10 +290,10 @@ void dbm_mpi_sum_int(int *values, const int count, const dbm_mpi_comm_t comm) {
void dbm_mpi_sum_int64(int64_t *values, const int count,
const dbm_mpi_comm_t comm) {
#if defined(__parallel)
int64_t *recvbuf = malloc(count * sizeof(int64_t));
void *recvbuf = dbm_mpi_alloc_mem(count * sizeof(int64_t));
CHECK(MPI_Allreduce(values, recvbuf, count, MPI_INT64_T, MPI_SUM, comm));
memcpy(values, recvbuf, count * sizeof(int64_t));
free(recvbuf);
dbm_mpi_free_mem(recvbuf);
#else
(void)comm; // mark used
(void)values;
@@ -308,10 +308,10 @@ void dbm_mpi_sum_int64(int64_t *values, const int count,
void dbm_mpi_sum_double(double *values, const int count,
const dbm_mpi_comm_t comm) {
#if defined(__parallel)
double *recvbuf = malloc(count * sizeof(double));
void *recvbuf = dbm_mpi_alloc_mem(count * sizeof(double));
CHECK(MPI_Allreduce(values, recvbuf, count, MPI_DOUBLE, MPI_SUM, comm));
memcpy(values, recvbuf, count * sizeof(double));
free(recvbuf);
dbm_mpi_free_mem(recvbuf);
#else
(void)comm; // mark used
(void)values;
@@ -433,4 +433,30 @@ void dbm_mpi_alltoallv_double(const double *sendbuf, const int *sendcounts,
#endif
}

/*******************************************************************************
* \brief Wrapper around MPI_Alloc_mem.
* \author Hans Pabst
******************************************************************************/
void *dbm_mpi_alloc_mem(size_t size) {
void *result = NULL;
#if defined(__parallel)
CHECK(MPI_Alloc_mem((MPI_Aint)size, MPI_INFO_NULL, &result));
#else
result = malloc(size);
#endif
return result;
}

/*******************************************************************************
* \brief Wrapper around MPI_Free_mem.
* \author Hans Pabst
******************************************************************************/
void dbm_mpi_free_mem(void *mem) {
#if defined(__parallel)
CHECK(MPI_Free_mem(mem));
#else
free(mem);
#endif
}

// EOF
13 changes: 13 additions & 0 deletions src/dbm/dbm_mpi.h
@@ -9,6 +9,7 @@
#define DBM_MPI_H

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#if defined(__parallel)
@@ -183,6 +184,18 @@ void dbm_mpi_alltoallv_double(const double *sendbuf, const int *sendcounts,
const int *recvcounts, const int *rdispls,
const dbm_mpi_comm_t comm);

/*******************************************************************************
* \brief Wrapper around MPI_Alloc_mem.
* \author Hans Pabst
******************************************************************************/
void *dbm_mpi_alloc_mem(size_t size);

/*******************************************************************************
* \brief Wrapper around MPI_Free_mem.
* \author Hans Pabst
******************************************************************************/
void dbm_mpi_free_mem(void *ptr);

#endif

// EOF
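
For reference, a hedged usage sketch of the two wrappers declared above; the helper name and buffer length are invented for illustration, and only dbm_mpi_alloc_mem/dbm_mpi_free_mem come from this commit.

#include <stddef.h>
#include "dbm_mpi.h"

/* Hypothetical helper: allocate a communication scratch buffer, use it,
 * and release it. Every dbm_mpi_alloc_mem is paired with dbm_mpi_free_mem,
 * exactly as malloc/free were paired before this commit. */
static void example_use_scratch(const size_t count) {
  double *buffer = dbm_mpi_alloc_mem(count * sizeof(double)); /* was malloc() */
  for (size_t i = 0; i < count; i++) {
    buffer[i] = 0.0; /* fill or pack data destined for an MPI exchange */
  }
  dbm_mpi_free_mem(buffer); /* was free() */
}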
15 changes: 8 additions & 7 deletions src/dbm/dbm_multiply_comm.c
@@ -13,6 +13,7 @@

#include "dbm_hyperparams.h"
#include "dbm_mempool.h"
#include "dbm_mpi.h"

/*******************************************************************************
* \brief Returns the larger of two given integer (missing from the C standard)
@@ -356,8 +357,8 @@ static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
// Can not parallelize over packs because there might be too few of them.
for (int ipack = 0; ipack < nsend_packs; ipack++) {
// Allocate send buffers.
dbm_pack_block_t *blks_send =
malloc(nblks_send_per_pack[ipack] * sizeof(dbm_pack_block_t));
dbm_pack_block_t *blks_send = dbm_mpi_alloc_mem(nblks_send_per_pack[ipack] *
sizeof(dbm_pack_block_t));
double *data_send =
dbm_mempool_host_malloc(ndata_send_per_pack[ipack] * sizeof(double));

@@ -379,7 +380,7 @@

// 2nd communication: Exchange blocks.
dbm_pack_block_t *blks_recv =
malloc(nblocks_recv * sizeof(dbm_pack_block_t));
dbm_mpi_alloc_mem(nblocks_recv * sizeof(dbm_pack_block_t));
int blks_send_count_byte[nranks], blks_send_displ_byte[nranks];
int blks_recv_count_byte[nranks], blks_recv_displ_byte[nranks];
for (int i = 0; i < nranks; i++) { // TODO: this is ugly!
Expand All @@ -391,7 +392,7 @@ static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
dbm_mpi_alltoallv_byte(
blks_send, blks_send_count_byte, blks_send_displ_byte, blks_recv,
blks_recv_count_byte, blks_recv_displ_byte, dist->comm);
free(blks_send);
dbm_mpi_free_mem(blks_send);

// 3rd communication: Exchange data counts.
// TODO: could be computed from blks_recv.
@@ -428,7 +429,7 @@ static dbm_packed_matrix_t pack_matrix(const bool trans_matrix,
packed.max_nblocks = max_nblocks;
packed.max_data_size = max_data_size;
packed.recv_pack.blocks =
malloc(packed.max_nblocks * sizeof(dbm_pack_block_t));
dbm_mpi_alloc_mem(packed.max_nblocks * sizeof(dbm_pack_block_t));
packed.recv_pack.data =
dbm_mempool_host_malloc(packed.max_data_size * sizeof(double));

@@ -497,10 +498,10 @@ static dbm_pack_t *sendrecv_pack(const int itick, const int nticks,
* \author Ole Schuett
******************************************************************************/
static void free_packed_matrix(dbm_packed_matrix_t *packed) {
free(packed->recv_pack.blocks);
dbm_mpi_free_mem(packed->recv_pack.blocks);
dbm_mempool_free(packed->recv_pack.data);
for (int ipack = 0; ipack < packed->nsend_packs; ipack++) {
free(packed->send_packs[ipack].blocks);
dbm_mpi_free_mem(packed->send_packs[ipack].blocks);
dbm_mempool_free(packed->send_packs[ipack].data);
}
free(packed->send_packs);
15 changes: 7 additions & 8 deletions src/dbm/dbm_multiply_gpu.c
@@ -34,19 +34,18 @@ void dbm_multiply_gpu_start(const int max_batch_size, const int nshards,

// Allocate device storage for batches.
const size_t size = nshards * max_batch_size * sizeof(dbm_task_t);
ctx->batches_dev = (dbm_task_t *)dbm_mempool_device_malloc(size);
ctx->batches_dev = dbm_mempool_device_malloc(size);

// Allocate and upload shards of result matrix C.
ctx->shards_c_dev =
(dbm_shard_gpu_t *)malloc(nshards * sizeof(dbm_shard_gpu_t));
ctx->shards_c_dev = malloc(nshards * sizeof(dbm_shard_gpu_t));
for (int i = 0; i < nshards; i++) {
const dbm_shard_t *shard_c_host = &ctx->shards_c_host[i];
dbm_shard_gpu_t *shard_c_dev = &ctx->shards_c_dev[i];
offloadStreamCreate(&shard_c_dev->stream);
shard_c_dev->data_size = shard_c_host->data_size;
shard_c_dev->data_allocated = shard_c_host->data_allocated;
shard_c_dev->data = (double *)dbm_mempool_device_malloc(
shard_c_dev->data_allocated * sizeof(double));
shard_c_dev->data =
dbm_mempool_device_malloc(shard_c_dev->data_allocated * sizeof(double));
offloadMemcpyAsyncHtoD(shard_c_dev->data, shard_c_host->data,
shard_c_dev->data_size * sizeof(double),
shard_c_dev->stream);
@@ -63,7 +62,7 @@ static void upload_pack(const dbm_pack_t *pack_host, dbm_pack_t *pack_dev,
const size_t size = pack_host->data_size * sizeof(double);
if (pack_dev->data_size < pack_host->data_size) {
dbm_mempool_free(pack_dev->data);
pack_dev->data = (double *)dbm_mempool_device_malloc(size);
pack_dev->data = dbm_mempool_device_malloc(size);
}
offloadMemcpyAsyncHtoD(pack_dev->data, pack_host->data, size, stream);
}
@@ -129,8 +128,8 @@ void dbm_multiply_gpu_process_batch(const int ntasks, const dbm_task_t *batch,
double *old_data_dev = shard_c_dev->data;
shard_c_dev->data_allocated =
ALLOCATION_FACTOR * shard_c_host->data_promised;
shard_c_dev->data = (double *)dbm_mempool_device_malloc(
shard_c_dev->data_allocated * sizeof(double));
shard_c_dev->data =
dbm_mempool_device_malloc(shard_c_dev->data_allocated * sizeof(double));
offloadMemcpyAsyncDtoD(shard_c_dev->data, old_data_dev,
shard_c_dev->data_size * sizeof(double),
shard_c_dev->stream);
6 changes: 0 additions & 6 deletions src/dbm/dbm_multiply_opencl.cl
@@ -21,12 +21,6 @@

#define SINT short

#define DIVUP(A, B) (((A) + (B) - 1) / (B))
#define NUP(N, UP) (DIVUP(N, UP) * (UP))
#define BLR(N, BN) (NUP(N, BN) - (N))

#define IDX(I, J, M, N) ((int)(I) * (N) + (J))
#define IDT(I, J, M, N) IDX(J, I, N, M)
#define X(T, I) (T)->I /* task can be taken by value or by pointer */
#define XA(T) X(T, offset_a)
#define XB(T) X(T, offset_b)
7 changes: 4 additions & 3 deletions src/dbm/dbm_shard.c
@@ -173,8 +173,8 @@ dbm_block_t *dbm_shard_promise_new_block(dbm_shard_t *shard, const int row,
// Grow blocks array if necessary.
if (shard->nblocks_allocated < shard->nblocks + 1) {
shard->nblocks_allocated = ALLOCATION_FACTOR * (shard->nblocks + 1);
shard->blocks =
realloc(shard->blocks, shard->nblocks_allocated * sizeof(dbm_block_t));
shard->blocks = (dbm_block_t *)realloc(
shard->blocks, shard->nblocks_allocated * sizeof(dbm_block_t));

// rebuild hashtable
free(shard->hashtable);
@@ -205,7 +205,8 @@ void dbm_shard_allocate_promised_blocks(dbm_shard_t *shard) {
// Reallocate data array if necessary.
if (shard->data_promised > shard->data_allocated) {
shard->data_allocated = ALLOCATION_FACTOR * shard->data_promised;
shard->data = realloc(shard->data, shard->data_allocated * sizeof(double));
shard->data =
(double *)realloc(shard->data, shard->data_allocated * sizeof(double));
}

// Zero new blocks.
