Skip to content

Commit

Permalink
Make tensor class dimension independent.
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-camp committed Oct 6, 2020
1 parent c6b0391 commit 0455091
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 454 deletions.
42 changes: 22 additions & 20 deletions include/emitc/emitc_mhlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ inline Src logical_xor(Src x, Src y) {
// BroadcastInDimOp
template <typename Dest, typename Src>
inline Dest
broadcast_in_dim(Src x, Tensor1D<int64_t, Src::rank> broadcast_dimensions) {
static_assert(Src::rank == 0, "Only 0-dim operand supported so far");
broadcast_in_dim(Src x, Tensor1D<int64_t, Src::rank()> broadcast_dimensions) {
static_assert(Src::rank() == 0, "Only 0-dim operand supported so far");

Dest z;
std::fill(z.begin(), z.end(), x[0]);
Expand All @@ -370,7 +370,7 @@ inline Dest concatenate(Src1 input1, Src... inputs) {
// concatenate all but the first input
// We need to build the correct return type for the rest of the inputs
using ET_Src = typename get_element_type<Src1>::type;
using Rest = typename concat<ET_Src, Dimension, Src...>::type;
using Rest = typename concat<Dimension, ET_Src, Src...>::type;
Rest rest = concatenate<Dimension, Rest, Src...>(inputs...);

Dest z;
Expand Down Expand Up @@ -409,7 +409,7 @@ inline Dest concatenate(Src1 input1, Src... inputs) {

// SliceOp
// Overload for 1d case
template <typename Dest, typename Src, IsTensor1D<Src> = true>
template <typename Dest, typename Src, IsTensorOfDim<1, Src> = true>
Dest slice(Src x, Tensor1D<int64_t, 1> start_indices,
Tensor1D<int64_t, 1> limit_indices, Tensor1D<int64_t, 1> strides) {
Dest z;
Expand All @@ -423,7 +423,7 @@ Dest slice(Src x, Tensor1D<int64_t, 1> start_indices,
}

// Overload for 2d case
template <typename Dest, typename Src, IsTensor2D<Src> = true>
template <typename Dest, typename Src, IsTensorOfDim<2, Src> = true>
Dest slice(Src x, Tensor1D<int64_t, 2> start_indices,
Tensor1D<int64_t, 2> limit_indices, Tensor1D<int64_t, 2> strides) {
Dest z;
Expand All @@ -440,14 +440,14 @@ Dest slice(Src x, Tensor1D<int64_t, 2> start_indices,

// DynamicSliceOp
// Overload for 1d case
template <typename Dest, typename Src, IsTensor1D<Src> = true>
template <typename Dest, typename Src, IsTensorOfDim<1, Src> = true>
Dest dynamic_slice(Src x, int64_t start_index,
Tensor1D<int64_t, 1> size_indices) {
auto clamp = [](int64_t value, int64_t minValue, int64_t maxValue) {
return std::max(minValue, std::min(maxValue, value));
};

int64_t dim_x = static_cast<int64_t>(Src::dimX);
int64_t dim_x = static_cast<int64_t>(Src::dim(0));
int64_t start_index_eff = clamp(start_index, 0, dim_x - size_indices[0]);
Tensor1D<int64_t, 1> start_indices{start_index_eff};
Tensor1D<int64_t, 1> limit_indices{start_index_eff + size_indices[0]};
Expand All @@ -457,15 +457,15 @@ Dest dynamic_slice(Src x, int64_t start_index,
}

// Overload for 2d case
template <typename Dest, typename Src, IsTensor2D<Src> = true>
template <typename Dest, typename Src, IsTensorOfDim<2, Src> = true>
Dest dynamic_slice(Src x, int64_t start_index_x, int64_t start_index_y,
Tensor1D<int64_t, 2> size_indices) {
auto clamp = [](int64_t value, int64_t minValue, int64_t maxValue) {
return std::max(minValue, std::min(maxValue, value));
};

int64_t dim_x = static_cast<int64_t>(Src::dimX);
int64_t dim_y = static_cast<int64_t>(Src::dimY);
int64_t dim_x = static_cast<int64_t>(Src::dim(0));
int64_t dim_y = static_cast<int64_t>(Src::dim(1));
int64_t start_index_x_eff = clamp(start_index_x, 0, dim_x - size_indices[0]);
int64_t start_index_y_eff = clamp(start_index_y, 0, dim_y - size_indices[1]);
Tensor1D<int64_t, 2> start_indices{start_index_x_eff, start_index_y_eff};
Expand All @@ -478,25 +478,25 @@ Dest dynamic_slice(Src x, int64_t start_index_x, int64_t start_index_y,

// DynamicUpdateSliceOp
// Overload for 1d case
template <typename Update, typename Src, IsTensor1D<Src> = true>
template <typename Update, typename Src, IsTensorOfDim<1, Src> = true>
Src dynamic_update_slice(Src x, Update update, int64_t start_index) {
auto clamp = [](int64_t value, int64_t minValue, int64_t maxValue) {
return std::max(minValue, std::min(maxValue, value));
};

Src z = x;

size_t start_index_eff = clamp(start_index, 0, Src::dimX - Update::dimX);
size_t start_index_eff = clamp(start_index, 0, Src::dim(0) - Update::dim(0));

for (size_t i = 0; i < Update::dimX; i++) {
for (size_t i = 0; i < Update::dim(0); i++) {
z(start_index_eff + i) = update(i);
}

return z;
}

// Overload for 2d case
template <typename Update, typename Src, IsTensor2D<Src> = true>
template <typename Update, typename Src, IsTensorOfDim<2, Src> = true>
Src dynamic_update_slice(Src x, Update update, int64_t start_index_x,
int64_t start_index_y) {
auto clamp = [](int64_t value, int64_t minValue, int64_t maxValue) {
Expand All @@ -505,11 +505,13 @@ Src dynamic_update_slice(Src x, Update update, int64_t start_index_x,

Src z = x;

size_t start_index_x_eff = clamp(start_index_x, 0, Src::dimX - Update::dimX);
size_t start_index_y_eff = clamp(start_index_y, 0, Src::dimY - Update::dimY);
size_t start_index_x_eff =
clamp(start_index_x, 0, Src::dim(0) - Update::dim(0));
size_t start_index_y_eff =
clamp(start_index_y, 0, Src::dim(1) - Update::dim(1));

for (size_t i = 0; i < Update::dimX; i++) {
for (size_t j = 0; j < Update::dimY; j++) {
for (size_t i = 0; i < Update::dim(0); i++) {
for (size_t j = 0; j < Update::dim(1); j++) {
z(start_index_x_eff + i, start_index_y_eff + j) = update(i, j);
}
}
Expand All @@ -527,7 +529,7 @@ inline Dest reshape(Src x) {
using ET_Dest = typename get_element_type<Dest>::type;

static_assert(std::is_same<ET_Src, ET_Dest>::value, "Element type mismatch");
static_assert(Src::size_ == Dest::size_, "Tensor size mismatch");
static_assert(Src::size() == Dest::size(), "Tensor size mismatch");

Dest z;

Expand All @@ -548,7 +550,7 @@ inline Src select(typename replace_element_type<bool, Src>::type pred,
Src on_true, Src on_false) {
Src z;

for (size_t i = 0; i < Src::size_; i++) {
for (size_t i = 0; i < Src::size(); i++) {
z[i] = pred[i] ? on_true[i] : on_false[i];
}

Expand Down

0 comments on commit 0455091

Please sign in to comment.