[mlir][sparse] use consistent type for COO object and sparse tensor storage

There was a slight mismatch between the double-typed COO and the actual
numerical type in the final sparse tensor storage (because the external
formats always store values as double). This minor revision removes that
inconsistency by using a properly typed COO and casting during the "add"
method instead. This also prepares for alternative ways of initializing
the COO object.
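
For context, a minimal sketch of the idea (names simplified, not the actual SparseTensor class in SparseUtils.cpp): the COO object is templated on the value type V, and the double read from the external format is converted to V when it is handed to add():

    #include <cstdint>
    #include <vector>

    template <typename V>
    struct MiniCOO {                     // simplified stand-in for the COO object
      struct Element {
        std::vector<uint64_t> indices;
        V value;                         // stored with the tensor's own type V
      };
      // The caller may pass a double (as read from the file); it is converted
      // to V here, so the stored elements match the final storage type.
      void add(const std::vector<uint64_t> &ind, V val) {
        elements.push_back(Element{ind, val});
      }
      std::vector<Element> elements;
    };

    // Usage: external formats supply doubles, but the COO keeps float values.
    // MiniCOO<float> coo;
    // coo.add({0, 3}, 1.5);             // 1.5 is a double literal, stored as float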

Reviewed By: gussmith23

Differential Revision: https://reviews.llvm.org/D107310
aartbik committed Aug 2, 2021
1 parent 65e9d7e commit 52c87e0
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -151,12 +151,12 @@ class SparseTensorStorageBase {
/// each differently annotated sparse tensor, this method provides a convenient
/// "one-size-fits-all" solution that simply takes an input tensor and
/// annotations to implement all required setup in a general manner.
template <typename P, typename I, typename V, typename Ve>
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
/// Constructs sparse tensor storage scheme following the given
/// per-rank dimension dense/sparse annotations.
SparseTensorStorage(SparseTensor<Ve> *tensor, uint8_t *sparsity)
SparseTensorStorage(SparseTensor<V> *tensor, uint8_t *sparsity)
: sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
// Provide hints on capacity.
// TODO: needs fine-tuning based on sparsity
@@ -191,14 +191,23 @@ class SparseTensorStorage : public SparseTensorStorageBase {
}
void getValues(std::vector<V> **out) override { *out = &values; }

// Factory method.
static SparseTensorStorage<P, I, V> *newSparseTensor(SparseTensor<V> *t,
uint8_t *s) {
t->sort(); // sort lexicographically
SparseTensorStorage<P, I, V> *n = new SparseTensorStorage<P, I, V>(t, s);
delete t;
return n;
}

private:
/// Initializes sparse tensor storage scheme from a memory-resident
/// representation of an external sparse tensor. This method prepares
/// the pointers and indices arrays under the given per-rank dimension
/// dense/sparse annotations.
void traverse(SparseTensor<Ve> *tensor, uint8_t *sparsity, uint64_t lo,
void traverse(SparseTensor<V> *tensor, uint8_t *sparsity, uint64_t lo,
uint64_t hi, uint64_t d) {
const std::vector<Element<Ve>> &elements = tensor->getElements();
const std::vector<Element<V>> &elements = tensor->getElements();
// Once dimensions are exhausted, insert the numerical values.
if (d == getRank()) {
values.push_back(lo < hi ? elements[lo].value : 0);
@@ -321,9 +330,9 @@ static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
}

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme. The external formats always store
/// the numerical values with the type double.
static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
/// sparse tensor in coordinate scheme.
template <typename V>
static SparseTensor<V> *openTensor(char *filename, uint64_t *perm) {
// Open the file.
FILE *file = fopen(filename, "r");
if (!file) {
@@ -347,7 +356,7 @@ static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
std::vector<uint64_t> indices(rank);
for (uint64_t r = 0; r < rank; r++)
indices[perm[r]] = idata[2 + r];
SparseTensor<double> *tensor = new SparseTensor<double>(indices, nnz);
SparseTensor<V> *tensor = new SparseTensor<V>(indices, nnz);
// Read all nonzero elements.
for (uint64_t k = 0; k < nnz; k++) {
uint64_t idx = -1;
@@ -359,28 +368,17 @@ static SparseTensor<double> *openTensor(char *filename, uint64_t *perm) {
// Add 0-based index.
indices[perm[r]] = idx - 1;
}
// The external formats always store the numerical values with the type
// double, but we cast these values to the sparse tensor object type.
double value;
if (fscanf(file, "%lg\n", &value) != 1) {
fprintf(stderr, "Cannot find next value in %s\n", filename);
exit(1);
}
tensor->add(indices, value);
}
// Close the file and return sorted tensor.
// Close the file and return tensor.
fclose(file);
tensor->sort(); // sort lexicographically
return tensor;
}
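
With this change, callers are expected to read the COO first and then hand it to the new factory, roughly as sketched below. The float/uint64_t type choices and the filename/perm/sparsity variables are assumptions standing in for whatever the runtime entry point supplies:

    // Sketch of the new flow (types chosen for illustration only):
    SparseTensor<float> *coo = openTensor<float>(filename, perm);   // unsorted COO
    SparseTensorStorageBase *storage =
        SparseTensorStorage<uint64_t, uint64_t, float>::newSparseTensor(coo, sparsity);
    // newSparseTensor() sorts the COO lexicographically, builds the compressed
    // storage, and deletes the COO object, so no separate cleanup is needed here.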

/// Templated reader.
template <typename P, typename I, typename V>
void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t *perm,
uint64_t size) {
SparseTensor<double> *t = openTensor(filename, perm);
assert(size == t->getRank()); // sparsity array must match rank
SparseTensorStorageBase *tensor =
new SparseTensorStorage<P, I, V, double>(t, sparsity);
delete t;
return tensor;
}

@@ -419,8 +417,11 @@ char *getTensorFilename(uint64_t id) {
}

#define CASE(p, i, v, P, I, V) \
if (ptrTp == (p) && indTp == (i) && valTp == (v)) \
return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)
if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
SparseTensor<V> *tensor = openTensor<V>(filename, perm); \
assert(asize == tensor->getRank()); \
return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity); \
}
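
For illustration, one instantiation of the rewritten macro, say CASE(p0, i0, v0, uint64_t, uint64_t, double) with hypothetical type codes p0/i0/v0, expands roughly to:

    // Hypothetical expansion; p0/i0/v0 stand in for whatever codes the caller uses.
    if (ptrTp == (p0) && indTp == (i0) && valTp == (v0)) {
      SparseTensor<double> *tensor = openTensor<double>(filename, perm);
      assert(asize == tensor->getRank());
      return SparseTensorStorage<uint64_t, uint64_t, double>::newSparseTensor(
          tensor, sparsity);
    }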

#define IMPL1(RET, NAME, TYPE, LIB) \
RET NAME(void *tensor) { \
