Clean up Tensor and Shape stringification #808

Closed · wants to merge 6 commits
3 changes: 2 additions & 1 deletion flashlight/fl/tensor/CMakeLists.txt
@@ -20,7 +20,8 @@ target_sources(
   ${CMAKE_CURRENT_LIST_DIR}/Random.cpp
   ${CMAKE_CURRENT_LIST_DIR}/Shape.cpp
   ${CMAKE_CURRENT_LIST_DIR}/TensorBackend.cpp
-  ${CMAKE_CURRENT_LIST_DIR}/TensorBase.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/TensorBase.cpp
+  ${CMAKE_CURRENT_LIST_DIR}/TensorAdapter.cpp
   ${CMAKE_CURRENT_LIST_DIR}/TensorExtension.cpp
   ${CMAKE_CURRENT_LIST_DIR}/Types.cpp
 )
9 changes: 5 additions & 4 deletions flashlight/fl/tensor/Index.cpp
@@ -9,13 +9,14 @@

 namespace fl {

-range::range(idx idx) : range(0, idx) {}
+range::range(const idx& i) : range(0, i) {}

-range::range(idx start, idx end) : range(start, end, /* stride */ 1) {}
+range::range(const idx& start, const idx& end)
+    : range(start, end, /* stride */ 1) {}

-range::range(idx start, idx end, Dim stride)
+range::range(const idx& start, const idx& end, const Dim stride)
     : // fl::end decays to int
-      start_(std::visit([](Dim idx) -> Dim { return idx; }, start)),
+      start_(std::visit([](const Dim idx) -> Dim { return idx; }, start)),
       // fl::end --> -1, else idx as Dim
       end_(
           std::holds_alternative<fl::end_t>(end)
6 changes: 3 additions & 3 deletions flashlight/fl/tensor/Index.h
@@ -44,18 +44,18 @@ class range {
   /**
    * Construct a range with the indices [0, idx) (i.e. [0, idx - 1])
    */
-  explicit range(idx idx);
+  explicit range(const idx& idx);

   /**
    * Construct a range with the indices [start, end) (i.e. [start, end - 1])
    */
-  range(idx start, idx end);
+  range(const idx& start, const idx& end);

   /**
    * Construct a range with the indices [start, end) (i.e. [start, end - 1])
    * with the given stride.
    */
-  range(idx start, idx end, Dim stride);
+  range(const idx& start, const idx& end, const Dim stride);

   Dim start() const;
   Dim end() const;
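The net effect of the Index changes is that idx arguments are taken by const reference everywhere. For orientation, a minimal usage sketch of these constructors (the fl::end sentinel follows the "fl::end --> -1" comment in Index.cpp above; values are illustrative):

fl::range r1(5);           // indices [0, 5): 0, 1, 2, 3, 4
fl::range r2(2, 8);        // indices [2, 8): 2, 3, ..., 7
fl::range r3(2, 8, 2);     // stride 2: indices 2, 4, 6
fl::range r4(1, fl::end);  // index 1 through the last index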
43 changes: 29 additions & 14 deletions flashlight/fl/tensor/Shape.cpp
@@ -10,6 +10,7 @@
 #include <algorithm>
 #include <limits>
 #include <numeric>
+#include <sstream>
 #include <stdexcept>

 namespace fl {
@@ -19,6 +20,16 @@ Shape::Shape(std::initializer_list<Dim> d) : Shape(std::vector<Dim>(d)) {}

 const Dim kEmptyShapeNumberOfElements = 1;

+void Shape::checkDimsOrThrow(const size_t dim) const {
+  if (dim >= ndim()) {
+    std::stringstream ss;
+    ss << "Shape index " << dim
+       << " out of bounds for shape with " << dims_.size()
+       << " dimensions.";
+    throw std::invalid_argument(ss.str());
+  }
+}
+
 Dim Shape::elements() const {
   if (dims_.size() == 0) {
     return kEmptyShapeNumberOfElements;
@@ -31,19 +42,17 @@ size_t Shape::ndim() const {
 }

 Dim Shape::dim(const size_t dim) const {
-  if (dim >= dims_.size()) {
-    throw std::invalid_argument(
-        "fl::Shape::dim - passed dimension is larger than "
-        "the number of dimensions in the shape");
-  }
+  checkDimsOrThrow(dim);
   return dims_[dim];
 }

 Dim& Shape::operator[](const size_t dim) {
+  checkDimsOrThrow(dim);
   return dims_[dim];
 }

 const Dim& Shape::operator[](const size_t dim) const {
+  checkDimsOrThrow(dim);
   return dims_[dim];
 }

@@ -64,15 +73,6 @@ bool Shape::operator!=(const std::initializer_list<Dim>& other) const {
   return !(this->operator==(other));
 }

-std::ostream& operator<<(std::ostream& ostr, const Shape& s) {
-  ostr << "(";
-  for (size_t i = 0; i < s.ndim(); ++i) {
-    ostr << s.dim(i) << (i == s.ndim() - 1 ? "" : ", ");
-  }
-  ostr << ")";
-  return ostr;
-}
-
 const std::vector<Dim>& Shape::get() const {
   return dims_;
 }
@@ -81,4 +81,19 @@ std::vector<Dim>& Shape::get() {
   return dims_;
 };

+std::string Shape::toString() const {
+  std::stringstream ss;
+  ss << "(";
+  for (size_t i = 0; i < ndim(); ++i) {
+    ss << dim(i) << (i == ndim() - 1 ? "" : ", ");
+  }
+  ss << ")";
+  return ss.str();
+}
+
+std::ostream& operator<<(std::ostream& ostr, const Shape& s) {
+  ostr << s.toString();
+  return ostr;
+}
+
 } // namespace fl
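Taken together, the Shape changes route all bounds checking through checkDimsOrThrow and make operator<< a thin wrapper over the new toString. A minimal behavior sketch (values illustrative, assuming the API above):

fl::Shape s({2, 3, 4});
std::cout << s << std::endl;     // prints "(2, 3, 4)" via operator<<
std::string str = s.toString();  // same text, no stream required
s[1] = 5;                        // operator[] is now bounds-checked too
s.dim(3);                        // throws std::invalid_argument: "Shape
                                 // index 3 out of bounds for shape with
                                 // 3 dimensions."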
11 changes: 11 additions & 0 deletions flashlight/fl/tensor/Shape.h
@@ -44,6 +44,12 @@ class Shape {
   // {} is a scalar shape.
   std::vector<Dim> dims_;

+  /**
+   * Check if a dimension is valid (i.e. in bounds) given the current size of
+   * the shape. If not valid, throws an exception.
+   */
+  void checkDimsOrThrow(const size_t dim) const;
+
  public:
   Shape() = default;
   ~Shape() = default;
@@ -107,6 +113,11 @@
    */
   const std::vector<Dim>& get() const;
   std::vector<Dim>& get();
+
+  /**
+   * Returns a string representation of the Shape
+   */
+  std::string toString() const;
 };

6 changes: 6 additions & 0 deletions flashlight/fl/tensor/TensorAdapter.h
@@ -212,6 +212,12 @@ class TensorAdapterBase {
    */
   virtual void* getContext() = 0;

+  /**
+   * Return a string representation of a Tensor. Not intended to be portable
+   * across backends.
+   */
+  virtual std::string toString() = 0;
+
   /**
    * Write a string representation of a tensor to an output stream.
    */
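Because toString() is pure virtual, each backend adapter supplies its own formatting; the doc comment above only promises a human-readable string, not a portable one. A hedged sketch of a helper built on the Tensor-level API (describeTensor is hypothetical, not part of this diff):

std::string describeTensor(const fl::Tensor& t) {
  // Combine the new Shape::toString with the backend-defined tensor printout.
  return "Tensor of shape " + t.shape().toString() + ":\n" + t.toString();
}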
1 change: 1 addition & 0 deletions flashlight/fl/tensor/TensorBackend.cpp
@@ -9,6 +9,7 @@

 namespace fl {
 namespace detail {
+
 bool areBackendsEqual(const Tensor& a, const Tensor& b) {
   return a.backendType() == b.backendType();
 }
35 changes: 35 additions & 0 deletions flashlight/fl/tensor/TensorBackend.h
@@ -8,10 +8,13 @@
 #pragma once

 #include <memory>
+#include <stdexcept>
 #include <type_traits>
+#include <unordered_map>
 #include <utility>

 #include "flashlight/fl/tensor/TensorBase.h"
+#include "flashlight/fl/tensor/TensorExtension.h"

 namespace fl {

@@ -31,6 +34,7 @@ class TensorBackend {
  public:
   TensorBackend() = default;
   virtual ~TensorBackend() = default;
+  virtual TensorBackendType backendType() const = 0;

   /* -------------------------- Compute Functions -------------------------- */
   virtual void sync() = 0;
@@ -122,6 +126,12 @@
       const SortMode sortMode) = 0;
   virtual Tensor
   sort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0;
+  virtual void sort(
+      Tensor& values,
+      Tensor& indices,
+      const Tensor& input,
+      const Dim axis,
+      const SortMode sortMode) = 0;
   virtual Tensor
   argsort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0;

@@ -232,6 +242,31 @@

   /************************** Utils ***************************/
   virtual void print(const Tensor& tensor) = 0;
+
+  /********************* Tensor Extensions **********************/
+  template <typename T>
+  T& getExtension() {
+    static_assert(
+        std::is_base_of<TensorExtensionBase, T>::value,
+        "TensorBackend::getExtension<T>() called with type T "
+        "that is not derived from TensorExtensionBase.");
+
+    TensorExtensionType e = T::getExtensionType();
+
+    // If an extension isn't present, instantiate it via its registered
+    // creation function - only do this once per extension.
+    if (extensions_.find(e) == extensions_.end()) {
+      auto& creationFunc =
+          detail::TensorExtensionRegistrar::getInstance()
+              .getTensorExtensionCreationFunc(this->backendType(), e);
+      extensions_.emplace(e, creationFunc());
+    }
+    return *(static_cast<T*>(extensions_.at(e).get()));
+  }
+
+ protected:
+  std::unordered_map<TensorExtensionType, std::unique_ptr<TensorExtensionBase>>
+      extensions_;
 };

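A hedged usage sketch for getExtension: MyExtension stands in for any type derived from TensorExtensionBase whose creation function is registered for this backend's TensorBackendType; the name is illustrative, not taken from this diff.

// The first call per backend instantiates the extension through its
// registered creation function; subsequent calls return the cached one.
fl::Tensor t = fl::full({2, 2}, 1.);
auto& ext = t.backend().getExtension<MyExtension>();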
27 changes: 26 additions & 1 deletion flashlight/fl/tensor/TensorBase.cpp
@@ -240,6 +240,10 @@ void* Tensor::getContext() const {
   return impl_->getContext();
 }

+std::string Tensor::toString() const {
+  return impl_->toString();
+}
+
 std::ostream& Tensor::operator<<(std::ostream& ostr) const {
   return impl_->operator<<(ostr);
 }
@@ -501,7 +505,7 @@ void topk(
     const Tensor& input,
     const unsigned k,
     const Dim axis,
-    const SortMode sortMode) {
+    const SortMode sortMode /* = SortMode::Descending */) {
   FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input);
   input.backend().topk(values, indices, input, k, axis, sortMode);
 }
@@ -510,6 +514,15 @@ Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) {
   return input.backend().sort(input, axis, sortMode);
 }

+void sort(
+    Tensor& values,
+    Tensor& indices,
+    const Tensor& input,
+    const Dim axis,
+    const SortMode sortMode /* = SortMode::Descending */) {
+  values.backend().sort(values, indices, input, axis, sortMode);
+}
+
 Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) {
   return input.backend().argsort(input, axis, sortMode);
 }
@@ -611,6 +624,10 @@ Tensor power(const Tensor& lhs, const double& rhs) {
   return lhs.backend().power(lhs, full(lhs.shape(), rhs));
 }

+Tensor power(const double& lhs, const Tensor& rhs) {
+  return rhs.backend().power(full(rhs.shape(), lhs), rhs);
+}
+
/******************************* BLAS ********************************/
Tensor matmul(
const Tensor& lhs,
@@ -764,4 +781,12 @@ bool allClose(
absTolerance;
}

+namespace detail {
+
+bool areTensorTypesEqual(const Tensor& a, const Tensor& b) {
+  return a.type() == b.type();
+}
+
+} // namespace detail
+
 } // namespace fl
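A short usage sketch of the free functions added in this file (fl::rand and the axis value are assumed for illustration):

fl::Tensor a = fl::rand({4, 4});
fl::Tensor values, indices;
// Two-output sort: sorted values plus the indices that produced them.
fl::sort(values, indices, a, /* axis = */ 0, fl::SortMode::Descending);
// Scalar base raised elementwise to tensor exponents.
fl::Tensor p = fl::power(2.0, a);
std::cout << p.toString() << std::endl;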