Skip to content

Commit

Permalink
Merge pull request #16839 from bangerth/tensor-product
Browse files Browse the repository at this point in the history
Make tensor*tensor implementation easier to read.
  • Loading branch information
tamiko committed Apr 4, 2024
2 parents 2727c89 + 0d23f44 commit 3910325
Showing 1 changed file with 47 additions and 9 deletions.
56 changes: 47 additions & 9 deletions include/deal.II/base/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2343,17 +2343,55 @@ constexpr inline DEAL_II_ALWAYS_INLINE
operator*(const Tensor<rank_1, dim, Number> &src1,
const Tensor<rank_2, dim, OtherNumber> &src2)
{
typename Tensor<rank_1 + rank_2 - 2,
dim,
typename ProductType<Number, OtherNumber>::type>::tensor_type
result{};
// Treat some common cases separately. Specifically, these are the dot
// product between two rank-1 tensors, and the product between a
// rank-2 tensor and a rank-1 tensor. Both of these lead to a linear
// loop over adjacent memory and can be dealt with efficiently; in the
// latter case (rank-2 times rank-1), we implement things by deferring
// to rank-1 times rank-1 dot products.
if constexpr ((rank_1 == 1) && (rank_2 == 1))
{
// This is a dot product between two rank-1 tensors. Write it out as
// a linear loop:
static_assert(dim > 0, "Tensors cannot have dimension zero.");
typename ProductType<Number, OtherNumber>::type sum = src1[0] * src2[0];
for (unsigned int i = 1; i < dim; ++i)
sum += src1[i] * src2[i];

TensorAccessors::internal::
ReorderedIndexView<0, rank_2, const Tensor<rank_2, dim, OtherNumber>>
reordered = TensorAccessors::reordered_index_view<0, rank_2>(src2);
TensorAccessors::contract<1, rank_1, rank_2, dim>(result, src1, reordered);
return sum;
}
else if constexpr ((rank_1 == 2) && (rank_2 == 1))
{
// This is a product between a rank-2 and a rank-1 tensor. This
// corresponds to taking dot products between the rows of the former
// and the latter.
typename Tensor<
rank_1 + rank_2 - 2,
dim,
typename ProductType<Number, OtherNumber>::type>::tensor_type result;
for (unsigned int i = 0; i < dim; ++i)
result[i] += src1[i] * src2;

return result;
return result;
}
else
{
// Treat all of the other cases using the more general contraction
// machinery.
typename Tensor<
rank_1 + rank_2 - 2,
dim,
typename ProductType<Number, OtherNumber>::type>::tensor_type result{};

TensorAccessors::internal::
ReorderedIndexView<0, rank_2, const Tensor<rank_2, dim, OtherNumber>>
reordered = TensorAccessors::reordered_index_view<0, rank_2>(src2);
TensorAccessors::contract<1, rank_1, rank_2, dim>(result,
src1,
reordered);

return result;
}
}


Expand Down

0 comments on commit 3910325

Please sign in to comment.