Skip to content

Commit

Permalink
DBM: Implement retain_sparsity for dbm_multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
oschuett committed Mar 8, 2022
1 parent 95ff93a commit 76175a8
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/dbm/dbm_multiply.c
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ static void multiply_packs(const bool transa, const bool transb,
const dbm_pack_t *pack_b,
const dbm_matrix_t *matrix_a,
const dbm_matrix_t *matrix_b, dbm_matrix_t *matrix_c,
const bool retain_sparsity,
const float *rows_left_max_eps, int64_t *flop,
backend_context_t *ctx) {

Expand Down Expand Up @@ -208,14 +209,19 @@ static void multiply_packs(const bool transa, const bool transb,
assert(n == col_size_c);
assert(k == row_size_right);

// Get C block.
dbm_block_t *blk_c = dbm_shard_lookup(shard_c, row_left, col_right);
if (blk_c == NULL && retain_sparsity) {
continue;
} else if (blk_c == NULL) {
blk_c =
dbm_shard_promise_new_block(shard_c, row_left, col_right, m * n);
}

// Count flops.
assert(m * n * k > 0);
flop_sum += 2 * m * n * k;

// Get C block.
dbm_block_t *blk_c =
dbm_shard_get_or_promise_block(shard_c, row_left, col_right, m * n);

// Invalidate norm of C block because its data is going to change.
blk_c->norm = -1.0;

Expand Down Expand Up @@ -253,7 +259,6 @@ void dbm_multiply(const bool transa, const bool transb, const double alpha,
int64_t *flop) {

assert(omp_get_num_threads() == 1);
assert(retain_sparsity == false); // TODO implement

// Denote left/right to matrices a/b after possible transpose.
const int nrows_left = (transa) ? matrix_a->ncols : matrix_a->nrows;
Expand Down Expand Up @@ -287,7 +292,7 @@ void dbm_multiply(const bool transa, const bool transb, const double alpha,
while (dbm_comm_iterator_next(iter, &pack_a, &pack_b)) {
backend_upload_packs(pack_a, pack_b, ctx);
multiply_packs(transa, transb, alpha, pack_a, pack_b, matrix_a, matrix_b,
matrix_c, rows_left_max_eps, flop, ctx);
matrix_c, retain_sparsity, rows_left_max_eps, flop, ctx);
}

// Start downloading matrix_c from the GPU.
Expand Down

0 comments on commit 76175a8

Please sign in to comment.