Make tensor*tensor implementation easier to read. #16839

Merged
merged 1 commit into from
Apr 4, 2024
56 changes: 47 additions & 9 deletions include/deal.II/base/tensor.h
@@ -2343,17 +2343,55 @@ constexpr inline DEAL_II_ALWAYS_INLINE
 operator*(const Tensor<rank_1, dim, Number> &src1,
           const Tensor<rank_2, dim, OtherNumber> &src2)
 {
-  typename Tensor<rank_1 + rank_2 - 2,
-                  dim,
-                  typename ProductType<Number, OtherNumber>::type>::tensor_type
-    result{};
+  // Treat some common cases separately. Specifically, these are the dot
+  // product between two rank-1 tensors, and the product between a
+  // rank-2 tensor and a rank-1 tensor. Both of these lead to a linear
+  // loop over adjacent memory and can be dealt with efficiently; in the
+  // latter case (rank-2 times rank-1), we implement things by deferring
+  // to rank-1 times rank-1 dot products.
+  if constexpr ((rank_1 == 1) && (rank_2 == 1))
+    {
+      // This is a dot product between two rank-1 tensors. Write it out as
+      // a linear loop:
+      static_assert(dim > 0, "Tensors cannot have dimension zero.");
+      typename ProductType<Number, OtherNumber>::type sum = src1[0] * src2[0];
+      for (unsigned int i = 1; i < dim; ++i)
+        sum += src1[i] * src2[i];

-  TensorAccessors::internal::
-    ReorderedIndexView<0, rank_2, const Tensor<rank_2, dim, OtherNumber>>
-      reordered = TensorAccessors::reordered_index_view<0, rank_2>(src2);
-  TensorAccessors::contract<1, rank_1, rank_2, dim>(result, src1, reordered);
+      return sum;
+    }
+  else if constexpr ((rank_1 == 2) && (rank_2 == 1))
+    {
+      // This is a product between a rank-2 and a rank-1 tensor. This
+      // corresponds to taking dot products between the rows of the former
+      // and the latter.
+      typename Tensor<
+        rank_1 + rank_2 - 2,
+        dim,
+        typename ProductType<Number, OtherNumber>::type>::tensor_type result;
+      for (unsigned int i = 0; i < dim; ++i)
+        result[i] += src1[i] * src2;

-  return result;
+      return result;
+    }
+  else
+    {
+      // Treat all of the other cases using the more general contraction
+      // machinery.
+      typename Tensor<
+        rank_1 + rank_2 - 2,
+        dim,
+        typename ProductType<Number, OtherNumber>::type>::tensor_type result{};
+
+      TensorAccessors::internal::
+        ReorderedIndexView<0, rank_2, const Tensor<rank_2, dim, OtherNumber>>
+          reordered = TensorAccessors::reordered_index_view<0, rank_2>(src2);
+      TensorAccessors::contract<1, rank_1, rank_2, dim>(result,
+                                                        src1,
+                                                        reordered);
Member

It took me considerable effort to come up with one generic implementation instead of special-casing contractions for different template parameters, Fortran-style.

And I fear right now that you are undoing a lot of this work.

Despite the rather intimidating template magic, the entire tensor contraction logic boils down to one loop where we do this:

     for (unsigned int i = 0; i < dim; ++i)
        result[i] += src1[i] * src2;

So why not add the vectorization bits there?

Relatedly, we could also ditch the ReorderedIndexView for the special case of a "last and first index" contraction and provide a specialized, vectorized recursive template for that one.
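
For readers following the thread, here is a minimal sketch of what a recursive "last and first index" contraction could look like, leaving vectorization aside. It is an illustration under stated assumptions, not deal.II code: `ToyTensor`, `add_scaled`, and `contract_last_first` are hypothetical names invented for this example, and the toy type only mimics the nested row-by-row layout of a tensor.

```cpp
#include <array>
#include <cstddef>

// Hypothetical toy tensor (NOT deal.II's Tensor class): a rank-r tensor is an
// array of dim tensors of rank r-1; a rank-1 tensor is an array of numbers.
template <int rank, std::size_t dim, typename Number>
struct ToyTensor
{
  using value_type = ToyTensor<rank - 1, dim, Number>;
  std::array<value_type, dim> values{};

  constexpr value_type &operator[](const std::size_t i) { return values[i]; }
  constexpr const value_type &operator[](const std::size_t i) const
  {
    return values[i];
  }
};

template <std::size_t dim, typename Number>
struct ToyTensor<1, dim, Number>
{
  using value_type = Number;
  std::array<Number, dim> values{};

  constexpr Number &operator[](const std::size_t i) { return values[i]; }
  constexpr const Number &operator[](const std::size_t i) const
  {
    return values[i];
  }
};

// Hypothetical helper: dst += factor * src, recursing down to scalar entries.
template <int rank, std::size_t dim, typename Number>
constexpr void
add_scaled(ToyTensor<rank, dim, Number>       &dst,
           const Number                       &factor,
           const ToyTensor<rank, dim, Number> &src)
{
  for (std::size_t i = 0; i < dim; ++i)
    {
      if constexpr (rank == 1)
        dst[i] += factor * src[i];
      else
        add_scaled(dst[i], factor, src[i]);
    }
}

// Hypothetical function: contract the last index of src1 with the first
// index of src2. The result has rank rank_1 + rank_2 - 2 (a plain number if
// both operands have rank 1).
template <int rank_1, int rank_2, std::size_t dim, typename Number>
constexpr auto
contract_last_first(const ToyTensor<rank_1, dim, Number> &src1,
                    const ToyTensor<rank_2, dim, Number> &src2)
{
  if constexpr ((rank_1 == 1) && (rank_2 == 1))
    {
      // Innermost case: a dot product over adjacent memory.
      static_assert(dim > 0, "Tensors cannot have dimension zero.");
      Number sum = src1[0] * src2[0];
      for (std::size_t i = 1; i < dim; ++i)
        sum += src1[i] * src2[i];
      return sum;
    }
  else
    {
      ToyTensor<rank_1 + rank_2 - 2, dim, Number> result{};
      if constexpr (rank_1 > 1)
        {
          // Peel off the leading index of src1 and recurse.
          for (std::size_t i = 0; i < dim; ++i)
            result[i] = contract_last_first(src1[i], src2);
        }
      else
        {
          // rank_1 == 1: result = sum_i src1[i] * src2[i], where each
          // src2[i] is a tensor of rank rank_2 - 1.
          for (std::size_t i = 0; i < dim; ++i)
            add_scaled(result, src1[i], src2[i]);
        }
      return result;
    }
}
```

With two 2x2 toy tensors holding the rows (1, 2), (3, 4) and (5, 6), (7, 8), `contract_last_first` yields the usual matrix product with rows (19, 22) and (43, 50). Whether a real implementation would be organized this way, and where the vectorized loop would sit, is left open in the discussion.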

Member Author

Of course that's the plan. I needed to create a place where I can put the vectorization in the next patch. This patch creates this place. It also, perhaps, makes the compiler's life easier for the two simplest cases.

Member Author

Separately, though: I'm not ditching the generic implementation. There remain the various contract() functions that I don't intend to touch. They're the main users of your general approach, which I continue to admire. I just think that in these two specific cases, things are so simple that we can perhaps allow ourselves to just write the obvious code :-)
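
The two special cases discussed in this thread can be tried out in isolation with the following sketch. It assumes a hypothetical `ToyTensor` type standing in for deal.II's `Tensor` (same nested indexing, nothing else), uses a single `Number` type instead of `ProductType`, and replaces the general contraction branch with a `static_assert`, so it only demonstrates the `if constexpr` dispatch and the "rows times vector" reduction to dot products.

```cpp
#include <array>
#include <cstddef>

// Hypothetical toy tensor (NOT deal.II's Tensor class): a rank-r tensor is an
// array of dim tensors of rank r-1; a rank-1 tensor is an array of numbers.
template <int rank, std::size_t dim, typename Number>
struct ToyTensor
{
  std::array<ToyTensor<rank - 1, dim, Number>, dim> values{};

  constexpr ToyTensor<rank - 1, dim, Number> &operator[](const std::size_t i)
  {
    return values[i];
  }
  constexpr const ToyTensor<rank - 1, dim, Number> &
  operator[](const std::size_t i) const
  {
    return values[i];
  }
};

template <std::size_t dim, typename Number>
struct ToyTensor<1, dim, Number>
{
  std::array<Number, dim> values{};

  constexpr Number &operator[](const std::size_t i) { return values[i]; }
  constexpr const Number &operator[](const std::size_t i) const
  {
    return values[i];
  }
};

// Dispatch on the ranks at compile time, mirroring the structure of the patch.
template <int rank_1, int rank_2, std::size_t dim, typename Number>
constexpr auto
operator*(const ToyTensor<rank_1, dim, Number> &src1,
          const ToyTensor<rank_2, dim, Number> &src2)
{
  if constexpr ((rank_1 == 1) && (rank_2 == 1))
    {
      // Dot product of two rank-1 tensors as a linear loop.
      static_assert(dim > 0, "Tensors cannot have dimension zero.");
      Number sum = src1[0] * src2[0];
      for (std::size_t i = 1; i < dim; ++i)
        sum += src1[i] * src2[i];
      return sum;
    }
  else
    {
      // Assumption of this sketch: no general contraction fallback exists,
      // so anything other than rank-2 times rank-1 is rejected.
      static_assert((rank_1 == 2) && (rank_2 == 1),
                    "Sketch covers only rank-1*rank-1 and rank-2*rank-1.");

      // Each entry is the dot product of one row of src1 with src2.
      ToyTensor<1, dim, Number> result{};
      for (std::size_t i = 0; i < dim; ++i)
        result[i] = src1[i] * src2;
      return result;
    }
}
```

For example, multiplying the toy rank-2 tensor with rows (1, 2) and (3, 4) by the rank-1 tensor (5, 6) gives (17, 39), each entry computed through the rank-1 branch.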


+      return result;
+    }
 }

