Skip to content

Commit

Permalink
Clean up Tensor and Shape stringification (flashlight#808)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: flashlight#808

See title. Add `Tensor::toString()` and `Shape::toString()` explicitly and use in `operator<<` for `ArrayFireTensor`.

Reviewed By: benoitsteiner

Differential Revision: D33159894

fbshipit-source-id: bc6a32658191933b37662c44ac13561e1bd3e305
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Jan 11, 2022
1 parent dbcd0f7 commit 55e8038
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 12 deletions.
24 changes: 15 additions & 9 deletions flashlight/fl/tensor/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,6 @@ bool Shape::operator!=(const std::initializer_list<Dim>& other) const {
return !(this->operator==(other));
}

std::ostream& operator<<(std::ostream& ostr, const Shape& s) {
ostr << "(";
for (size_t i = 0; i < s.ndim(); ++i) {
ostr << s.dim(i) << (i == s.ndim() - 1 ? "" : ", ");
}
ostr << ")";
return ostr;
}

const std::vector<Dim>& Shape::get() const {
return dims_;
}
Expand All @@ -90,4 +81,19 @@ std::vector<Dim>& Shape::get() {
return dims_;
};

std::string Shape::toString() const {
std::stringstream ss;
ss << "(";
for (size_t i = 0; i < ndim(); ++i) {
ss << dim(i) << (i == ndim() - 1 ? "" : ", ");
}
ss << ")";
return ss.str();
}

std::ostream& operator<<(std::ostream& ostr, const Shape& s) {
ostr << s.toString();
return ostr;
}

} // namespace fl
5 changes: 5 additions & 0 deletions flashlight/fl/tensor/Shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ class Shape {
*/
const std::vector<Dim>& get() const;
std::vector<Dim>& get();

/**
* Returns a string representation of the Shape
*/
std::string toString() const;
};

/**
Expand Down
6 changes: 6 additions & 0 deletions flashlight/fl/tensor/TensorAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ class TensorAdapterBase {
*/
virtual void* getContext() = 0;

/**
* Return a string representation of a Tensor. Not intended to be portable
* across backends.
*/
virtual std::string toString() = 0;

/**
* Write a string representation of a tensor to an output stream.
*/
Expand Down
4 changes: 4 additions & 0 deletions flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ void* Tensor::getContext() const {
return impl_->getContext();
}

std::string Tensor::toString() const {
return impl_->toString();
}

std::ostream& Tensor::operator<<(std::ostream& ostr) const {
return impl_->operator<<(ostr);
}
Expand Down
11 changes: 10 additions & 1 deletion flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,15 @@ class Tensor {
*/
void* getContext() const;

/**
* Returns a string representation of a Tensor. NOTE: This is
* backend-dependent. See Flashlight's serialization utilities for ways to
* serialize Tensors that are portable across Tensor backends.
*
* @return a string representation of the Tensor.
*/
std::string toString() const;

/**
* Write a string representation of a tensor to an output stream.
*/
Expand Down Expand Up @@ -1391,7 +1400,7 @@ Tensor all(
/**
* Write a string representation of a tensor to an output stream.
*/
std::ostream& operator<<(std::ostream& ostr, const Tensor& s);
std::ostream& operator<<(std::ostream& ostr, const Tensor& t);

/**
* Print a string representation of a tensor to standard out.
Expand Down
6 changes: 5 additions & 1 deletion flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,12 @@ void* ArrayFireTensor::getContext() {
return nullptr; // noop
}

std::string ArrayFireTensor::toString() {
return std::string(af::toString("ArrayFireTensor", getHandle()));
}

std::ostream& ArrayFireTensor::operator<<(std::ostream& ostr) {
ostr << "ArrayFireTensor " << std::string(af::toString("", getHandle()));
ostr << this->toString();
return ostr;
}

Expand Down
1 change: 1 addition & 0 deletions flashlight/fl/tensor/backend/af/ArrayFireTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class ArrayFireTensor : public TensorAdapterBase {
Tensor asContiguousTensor() override;
void setContext(void* context) override; // noop
void* getContext() override; // noop
std::string toString() override;
std::ostream& operator<<(std::ostream& ostr) override;

/******************** Assignment Operators ********************/
Expand Down
15 changes: 15 additions & 0 deletions flashlight/fl/test/tensor/ShapeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,18 @@ TEST(ShapeTest, Indexing) {
ASSERT_EQ(a[3], 2);
ASSERT_THROW(a[4], std::invalid_argument);
}

TEST(ShapeTest, string) {
auto checkShapeStrEqual = [](const Shape& s, const std::string& str) -> void {
auto sStr = s.toString();
ASSERT_EQ(sStr, str);
std::stringstream ss;
ss << sStr;
ASSERT_EQ(sStr, ss.str());
};

checkShapeStrEqual(Shape({3, 4, 7, 9}), "(3, 4, 7, 9)");
checkShapeStrEqual(Shape({}), "()");
checkShapeStrEqual(Shape({0}), "(0)");
checkShapeStrEqual(Shape({7, 7, 7, 7, 7, 7, 7}), "(7, 7, 7, 7, 7, 7, 7)");
}
4 changes: 3 additions & 1 deletion flashlight/fl/test/tensor/TensorBaseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ TEST(TensorBaseTest, fromScalar) {
ASSERT_EQ(a.shape(), Shape({}));
}

TEST(TensorBaseTest, ostream) {
TEST(TensorBaseTest, string) {
// Different backends might print tensors differently - check for consistency
// across two identical tensors
auto a = fl::full({3, 4, 5}, 6.);
auto b = fl::full({3, 4, 5}, 6.);
ASSERT_EQ(a.toString(), b.toString());

std::stringstream ssa, ssb;
ssa << a;
ssb << b;
Expand Down

0 comments on commit 55e8038

Please sign in to comment.