Skip to content

Commit

Permalink
[mlir][sparse] Add readCOOElement for reading a sparse tensor element…
Browse files Browse the repository at this point in the history
… from files.

Use the routine for openSparseTensorCOO and getSparseTensorReaderNext.

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D135732
  • Loading branch information
bixia1 committed Oct 16, 2022
1 parent dd3d8dd commit d18bfb2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 49 deletions.
107 changes: 64 additions & 43 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,44 @@
namespace mlir {
namespace sparse_tensor {

namespace detail {

template <typename T>
struct is_complex final : public std::false_type {};

template <typename T>
struct is_complex<std::complex<T>> final : public std::true_type {};

/// Reads an element of a non-complex type for the current indices in
/// coordinate scheme.
template <typename V>
inline std::enable_if_t<!is_complex<V>::value, V>
readCOOValue(char **linePtr, bool is_pattern) {
// The external formats always store these numerical values with the type
// double, but we cast these values to the sparse tensor object type.
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
}

/// Reads an element of a complex type for the current indices in
/// coordinate scheme.
template <typename V>
inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
bool is_pattern) {
// Read two values to make a complex. The external formats always store
// numerical values with the type double, but we cast these values to the
// sparse tensor object type. For a pattern tensor, we arbitrarily pick the
// value 1 for all entries.
double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
// Avoiding brace-notation since that forbids narrowing to `float`.
return V(re, im);
}

} // namespace detail

//===----------------------------------------------------------------------===//

// TODO: benchmark whether to keep various methods inline vs moving them
// off to the cpp file.

Expand Down Expand Up @@ -132,6 +170,31 @@ class SparseTensorReader final {
/// valid after parsing the header.
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const;

/// Reads a sparse tensor element from the next line in the input file and
/// returns the value of the element. Stores the coordinates of the element
/// to the `indices` array.
template <typename V>
V readCOOElement(uint64_t rank, uint64_t *indices,
const uint64_t *perm = nullptr) {
assert(rank == getRank() && "Rank mismatch");
char *linePtr = readLine();
if (perm)
for (uint64_t r = 0; r < rank; ++r) {
// Parse the 1-based index.
uint64_t idx = strtoul(linePtr, &linePtr, 10);
// Store the 0-based index.
indices[perm[r]] = idx - 1;
}
else
for (uint64_t r = 0; r < rank; ++r) {
// Parse the 1-based index.
uint64_t idx = strtoul(linePtr, &linePtr, 10);
// Store the 0-based index.
indices[r] = idx - 1;
}
return detail::readCOOValue<V>(&linePtr, isPattern());
}

private:
/// Reads the MME header of a general sparse matrix of type real.
void readMMEHeader();
Expand All @@ -152,41 +215,6 @@ class SparseTensorReader final {
};

//===----------------------------------------------------------------------===//
namespace detail {

template <typename T>
struct is_complex final : public std::false_type {};

template <typename T>
struct is_complex<std::complex<T>> final : public std::true_type {};

/// Reads an element of a non-complex type for the current indices in
/// coordinate scheme.
template <typename V>
inline std::enable_if_t<!is_complex<V>::value, V>
readCOOValue(char **linePtr, bool is_pattern) {
// The external formats always store these numerical values with the type
// double, but we cast these values to the sparse tensor object type.
// For a pattern tensor, we arbitrarily pick the value 1 for all entries.
return is_pattern ? 1.0 : strtod(*linePtr, linePtr);
}

/// Reads an element of a complex type for the current indices in
/// coordinate scheme.
template <typename V>
inline std::enable_if_t<is_complex<V>::value, V> readCOOValue(char **linePtr,
bool is_pattern) {
// Read two values to make a complex. The external formats always store
// numerical values with the type double, but we cast these values to the
// sparse tensor object type. For a pattern tensor, we arbitrarily pick the
// value 1 for all entries.
double re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
double im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
// Avoiding brace-notation since that forbids narrowing to `float`.
return V(re, im);
}

} // namespace detail

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
Expand All @@ -211,14 +239,7 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape,
// Read all nonzero elements.
std::vector<uint64_t> indices(rank);
for (uint64_t k = 0; k < nnz; ++k) {
char *linePtr = stfile.readLine();
for (uint64_t r = 0; r < rank; ++r) {
// Parse the 1-based index.
uint64_t idx = strtoul(linePtr, &linePtr, 10);
// Add the 0-based index.
indices[perm[r]] = idx - 1;
}
const V value = detail::readCOOValue<V>(&linePtr, stfile.isPattern());
const V value = stfile.readCOOElement<V>(rank, indices.data(), perm);
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
coo->add(indices, value);
// We currently chose to deal with symmetric matrices by fully
Expand Down
7 changes: 1 addition & 6 deletions mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,13 +626,8 @@ void delSparseTensorReader(void *p) {
index_type *indices = iref->data + iref->offset; \
SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p); \
index_type rank = stfile->getRank(); \
char *linePtr = stfile->readLine(); \
for (index_type r = 0; r < rank; ++r) { \
uint64_t idx = strtoul(linePtr, &linePtr, 10); \
indices[r] = idx - 1; \
} \
V *value = vref->data + vref->offset; \
*value = detail::readCOOValue<V>(&linePtr, stfile->isPattern()); \
*value = stfile->readCOOElement<V>(rank, indices); \
}
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
#undef IMPL_GETNEXT
Expand Down

0 comments on commit d18bfb2

Please sign in to comment.