Skip to content

Commit

Permalink
DBM: Switch to 2D sharding for better OpenMP data locality
Browse files Browse the repository at this point in the history
  • Loading branch information
oschuett committed Aug 26, 2022
1 parent d084757 commit 245edd6
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 226 deletions.
45 changes: 41 additions & 4 deletions src/dbm/dbm_distribution.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,25 @@
/*----------------------------------------------------------------------------*/

#include <assert.h>
#include <math.h>
#include <omp.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#include "dbm_distribution.h"
#include "dbm_hyperparams.h"

/*******************************************************************************
* \brief Private routine for creating a new one dimensional distribution.
* \author Ole Schuett
******************************************************************************/
static void dbm_dist_1d_new(dbm_dist_1d_t *dist, const int length,
const int coords[length],
const dbm_mpi_comm_t comm) {
const int coords[length], const dbm_mpi_comm_t comm,
const int nshards) {
dist->comm = comm;
dist->nshards = nshards;
dist->my_rank = dbm_mpi_comm_rank(comm);
dist->nranks = dbm_mpi_comm_size(comm);
dist->length = length;
Expand Down Expand Up @@ -63,6 +66,36 @@ static void dbm_dist_1d_free(dbm_dist_1d_t *dist) {
dbm_mpi_comm_free(&dist->comm);
}

/*******************************************************************************
* \brief Returns the larger of two given integer (missing from the C standard)
* \author Ole Schuett
******************************************************************************/
static inline int imax(int x, int y) { return (x > y ? x : y); }

/*******************************************************************************
* \brief Private routine for TODO
* \author Ole Schuett
******************************************************************************/
static int find_best_nrow_shards(const int nshards, const int nrows,
const int ncols) {
const double target = (double)imax(nrows, 1) / (double)imax(ncols, 1);
int best_nrow_shards = nshards;
double best_error = fabs(target - best_nrow_shards);

for (int nrow_shards = 1; nrow_shards <= nshards; nrow_shards++) {
const int ncol_shards = nshards / nrow_shards;
if (nrow_shards * ncol_shards != nshards)
continue; // Not a factor of nshards.
const double ratio = (double)nrow_shards / (double)ncol_shards;
const double error = fabs(log(target / ratio));
if (error < best_error) {
best_error = error;
best_nrow_shards = nrow_shards;
}
}
return best_nrow_shards;
}

/*******************************************************************************
* \brief Creates a new two dimensional distribution.
* \author Ole Schuett
Expand All @@ -85,8 +118,12 @@ void dbm_distribution_new(dbm_distribution_t **dist_out, const int fortran_comm,
const int col_dim_remains[2] = {0, 1};
const dbm_mpi_comm_t col_comm = dbm_mpi_cart_sub(dist->comm, col_dim_remains);

dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm);
dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm);
const int nshards = SHARDS_PER_THREAD * omp_get_max_threads();
const int nrow_shards = find_best_nrow_shards(nshards, nrows, ncols);
const int ncol_shards = nshards / nrow_shards;

dbm_dist_1d_new(&dist->rows, nrows, row_dist, row_comm, nrow_shards);
dbm_dist_1d_new(&dist->cols, ncols, col_dist, col_comm, ncol_shards);

assert(*dist_out == NULL);
*dist_out = dist;
Expand Down
1 change: 1 addition & 0 deletions src/dbm/dbm_distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ typedef struct {
dbm_mpi_comm_t comm; // 1D communicator
int nranks;
int my_rank;
int nshards; // Number of shards for distributing blocks across threads.
} dbm_dist_1d_t;

/*******************************************************************************
Expand Down
5 changes: 0 additions & 5 deletions src/dbm/dbm_hyperparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@ static const int BATCH_NUM_BUCKETS = 1000;
static const int INITIAL_NBLOCKS_ALLOCATED = 100;
static const int INITIAL_DATA_ALLOCATED = 1024;

// Choosing size as power of two allows to replace modulo with bitwise AND.
static const int PACK_HASH_SIZE = 1024;
static const int PACK_HASH_MASK = 1023; // PACK_HASH_SIZE - 1
static const int PACK_HASH_PRIME = 509; // Closest prime to PACK_HASH_SIZE / 2.

#endif

// EOF
52 changes: 25 additions & 27 deletions src/dbm/dbm_matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ void dbm_create(dbm_matrix_t **matrix_out, dbm_distribution_t *dist,
matrix->col_sizes = malloc(size);
memcpy(matrix->col_sizes, col_sizes, size);

matrix->nshards = SHARDS_PER_THREAD * omp_get_max_threads();
matrix->shards = malloc(matrix->nshards * sizeof(dbm_shard_t));
matrix->shards = malloc(dbm_get_num_shards(matrix) * sizeof(dbm_shard_t));
#pragma omp parallel for
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_init(&matrix->shards[ishard]);
}

Expand All @@ -69,7 +68,7 @@ void dbm_release(dbm_matrix_t *matrix) {
free(matrix->name);
free(matrix->row_sizes);
free(matrix->col_sizes);
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_release(&matrix->shards[ishard]);
}
free(matrix->shards);
Expand All @@ -93,11 +92,10 @@ void dbm_copy(dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b) {
assert(matrix_b->col_sizes[i] == matrix_a->col_sizes[i]);
}

assert(matrix_a->nshards == matrix_b->nshards);
assert(matrix_a->dist == matrix_b->dist);

#pragma omp parallel for schedule(dynamic)
for (int ishard = 0; ishard < matrix_a->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix_a); ishard++) {
dbm_shard_copy(&matrix_a->shards[ishard], &matrix_b->shards[ishard]);
}
}
Expand Down Expand Up @@ -125,7 +123,7 @@ void dbm_redistribute(const dbm_matrix_t *matrix, dbm_matrix_t *redist) {
// 1st pass: Compute send_count.
int send_count[nranks];
memset(send_count, 0, nranks * sizeof(int));
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
for (int iblock = 0; iblock < shard->nblocks; iblock++) {
const dbm_block_t *blk = &shard->blocks[iblock];
Expand Down Expand Up @@ -157,7 +155,7 @@ void dbm_redistribute(const dbm_matrix_t *matrix, dbm_matrix_t *redist) {
// 2nd pass: Fill send_data.
int send_data_positions[nranks];
memcpy(send_data_positions, send_displ, nranks * sizeof(int));
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
for (int iblock = 0; iblock < shard->nblocks; iblock++) {
const dbm_block_t *blk = &shard->blocks[iblock];
Expand Down Expand Up @@ -214,7 +212,7 @@ void dbm_get_block_p(dbm_matrix_t *matrix, const int row, const int col,
*col_size = matrix->col_sizes[col];
*block = NULL;

const int ishard = row % matrix->nshards;
const int ishard = dbm_get_shard_index(matrix, row, col);
const dbm_shard_t *shard = &matrix->shards[ishard];
dbm_block_t *blk = dbm_shard_lookup(shard, row, col);
if (blk != NULL) {
Expand All @@ -236,7 +234,7 @@ void dbm_put_block(dbm_matrix_t *matrix, const int row, const int col,
const int col_size = matrix->col_sizes[col];
const int block_size = row_size * col_size;

const int ishard = row % matrix->nshards;
const int ishard = dbm_get_shard_index(matrix, row, col);
dbm_shard_t *shard = &matrix->shards[ishard];
omp_set_lock(&shard->lock);
dbm_block_t *blk =
Expand All @@ -260,7 +258,7 @@ void dbm_clear(dbm_matrix_t *matrix) {
assert(omp_get_num_threads() == 1);

#pragma omp parallel for schedule(dynamic)
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
shard->nblocks = 0;
shard->data_size = 0;
Expand All @@ -283,7 +281,7 @@ void dbm_filter(dbm_matrix_t *matrix, const double eps) {
}

#pragma omp parallel for schedule(dynamic)
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
const int old_nblocks = shard->nblocks;
shard->nblocks = 0;
Expand Down Expand Up @@ -330,7 +328,7 @@ void dbm_reserve_blocks(dbm_matrix_t *matrix, const int nblocks,
"Please call dbm_reserve_blocks within an OpenMP parallel region.");
const int my_rank = matrix->dist->my_rank;
for (int i = 0; i < nblocks; i++) {
const int ishard = rows[i] % matrix->nshards;
const int ishard = dbm_get_shard_index(matrix, rows[i], cols[i]);
dbm_shard_t *shard = &matrix->shards[ishard];
omp_set_lock(&shard->lock);
assert(0 <= rows[i] && rows[i] < matrix->nrows);
Expand All @@ -345,7 +343,7 @@ void dbm_reserve_blocks(dbm_matrix_t *matrix, const int nblocks,
#pragma omp barrier

#pragma omp for schedule(dynamic)
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
dbm_shard_allocate_promised_blocks(shard);
}
Expand All @@ -366,7 +364,7 @@ void dbm_scale(dbm_matrix_t *matrix, const double alpha) {
}

#pragma omp parallel for schedule(dynamic)
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
for (int i = 0; i < shard->data_size; i++) {
shard->data[i] *= alpha;
Expand All @@ -382,7 +380,7 @@ void dbm_zero(dbm_matrix_t *matrix) {
assert(omp_get_num_threads() == 1);

#pragma omp parallel for schedule(dynamic)
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
memset(shard->data, 0, shard->data_size * sizeof(double));
}
Expand All @@ -394,11 +392,10 @@ void dbm_zero(dbm_matrix_t *matrix) {
******************************************************************************/
void dbm_add(dbm_matrix_t *matrix_a, const dbm_matrix_t *matrix_b) {
assert(omp_get_num_threads() == 1);
assert(matrix_a->nshards == matrix_b->nshards);
assert(matrix_a->dist == matrix_b->dist);

#pragma omp parallel for schedule(dynamic)
for (int ishard = 0; ishard < matrix_b->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix_b); ishard++) {
dbm_shard_t *shard_a = &matrix_a->shards[ishard];
const dbm_shard_t *shard_b = &matrix_b->shards[ishard];
for (int iblock = 0; iblock < shard_b->nblocks; iblock++) {
Expand Down Expand Up @@ -433,7 +430,7 @@ void dbm_iterator_start(dbm_iterator_t **iter_out, const dbm_matrix_t *matrix) {
iter->matrix = matrix;
iter->next_block = 0;
iter->next_shard = omp_get_thread_num();
while (iter->next_shard < matrix->nshards &&
while (iter->next_shard < dbm_get_num_shards(matrix) &&
matrix->shards[iter->next_shard].nblocks == 0) {
iter->next_shard += omp_get_num_threads();
}
Expand All @@ -447,7 +444,8 @@ void dbm_iterator_start(dbm_iterator_t **iter_out, const dbm_matrix_t *matrix) {
******************************************************************************/
int dbm_iterator_num_blocks(const dbm_iterator_t *iter) {
int num_blocks = 0;
for (int ishard = omp_get_thread_num(); ishard < iter->matrix->nshards;
for (int ishard = omp_get_thread_num();
ishard < dbm_get_num_shards(iter->matrix);
ishard += omp_get_num_threads()) {
num_blocks += iter->matrix->shards[ishard].nblocks;
}
Expand All @@ -459,7 +457,7 @@ int dbm_iterator_num_blocks(const dbm_iterator_t *iter) {
* \author Ole Schuett
******************************************************************************/
bool dbm_iterator_blocks_left(const dbm_iterator_t *iter) {
return iter->next_shard < iter->matrix->nshards;
return iter->next_shard < dbm_get_num_shards(iter->matrix);
}

/*******************************************************************************
Expand All @@ -469,7 +467,7 @@ bool dbm_iterator_blocks_left(const dbm_iterator_t *iter) {
void dbm_iterator_next_block(dbm_iterator_t *iter, int *row, int *col,
double **block, int *row_size, int *col_size) {
const dbm_matrix_t *matrix = iter->matrix;
assert(iter->next_shard < matrix->nshards);
assert(iter->next_shard < dbm_get_num_shards(matrix));
const dbm_shard_t *shard = &matrix->shards[iter->next_shard];
assert(iter->next_block < shard->nblocks);
dbm_block_t *blk = &shard->blocks[iter->next_block];
Expand All @@ -484,7 +482,7 @@ void dbm_iterator_next_block(dbm_iterator_t *iter, int *row, int *col,
if (iter->next_block >= shard->nblocks) {
// Advance to the next non-empty shard...
iter->next_shard += omp_get_num_threads();
while (iter->next_shard < matrix->nshards &&
while (iter->next_shard < dbm_get_num_shards(matrix) &&
matrix->shards[iter->next_shard].nblocks == 0) {
iter->next_shard += omp_get_num_threads();
}
Expand All @@ -504,7 +502,7 @@ void dbm_iterator_stop(dbm_iterator_t *iter) { free(iter); }
******************************************************************************/
double dbm_checksum(const dbm_matrix_t *matrix) {
double checksum = 0.0;
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
const dbm_shard_t *shard = &matrix->shards[ishard];
for (int i = 0; i < shard->data_size; i++) {
checksum += shard->data[i] * shard->data[i];
Expand All @@ -520,7 +518,7 @@ double dbm_checksum(const dbm_matrix_t *matrix) {
******************************************************************************/
double dbm_maxabs(const dbm_matrix_t *matrix) {
double maxabs = 0.0;
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
dbm_shard_t *shard = &matrix->shards[ishard];
for (int i = 0; i < shard->data_size; i++) {
maxabs = fmax(maxabs, fabs(shard->data[i]));
Expand All @@ -542,7 +540,7 @@ const char *dbm_get_name(const dbm_matrix_t *matrix) { return matrix->name; }
******************************************************************************/
int dbm_get_nze(const dbm_matrix_t *matrix) {
int nze = 0;
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
nze += matrix->shards[ishard].data_size;
}
return nze;
Expand All @@ -554,7 +552,7 @@ int dbm_get_nze(const dbm_matrix_t *matrix) {
******************************************************************************/
int dbm_get_num_blocks(const dbm_matrix_t *matrix) {
int nblocks = 0;
for (int ishard = 0; ishard < matrix->nshards; ishard++) {
for (int ishard = 0; ishard < dbm_get_num_shards(matrix); ishard++) {
nblocks += matrix->shards[ishard].nblocks;
}
return nblocks;
Expand Down
20 changes: 19 additions & 1 deletion src/dbm/dbm_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ typedef struct {
int *row_sizes;
int *col_sizes;

int nshards;
dbm_shard_t *shards;
} dbm_matrix_t;

Expand Down Expand Up @@ -226,6 +225,25 @@ int dbm_get_stored_coordinates(const dbm_matrix_t *matrix, const int row,
******************************************************************************/
const dbm_distribution_t *dbm_get_distribution(const dbm_matrix_t *matrix);

/*******************************************************************************
* \brief Internal routine that returns the number of shards for given matrix.
* \author Ole Schuett
******************************************************************************/
static inline int dbm_get_num_shards(const dbm_matrix_t *matrix) {
return matrix->dist->rows.nshards * matrix->dist->cols.nshards;
}

/*******************************************************************************
* \brief Internal routine for getting a block's shard index.
* \author Ole Schuett
******************************************************************************/
static inline int dbm_get_shard_index(const dbm_matrix_t *matrix, const int row,
const int col) {
const int shard_row = row % matrix->dist->rows.nshards;
const int shard_col = col % matrix->dist->cols.nshards;
return shard_row * matrix->dist->cols.nshards + shard_col;
}

#endif

// EOF

0 comments on commit 245edd6

Please sign in to comment.