diff --git a/include/omath/linear_algebra/mat.hpp b/include/omath/linear_algebra/mat.hpp index 4f5a7ed0..99f87757 100644 --- a/include/omath/linear_algebra/mat.hpp +++ b/include/omath/linear_algebra/mat.hpp @@ -155,28 +155,11 @@ namespace omath constexpr Mat operator*(const Mat& other) const { - Mat result; - if constexpr (StoreType == MatStoreType::ROW_MAJOR) - for (std::size_t i = 0; i < Rows; ++i) - for (std::size_t k = 0; k < Columns; ++k) - { - const Type aik = at(i, k); - for (std::size_t j = 0; j < OtherColumns; ++j) - result.at(i, j) += aik * other.at(k, j); - } - else if constexpr (StoreType == MatStoreType::COLUMN_MAJOR) - for (std::size_t j = 0; j < OtherColumns; ++j) - for (std::size_t k = 0; k < Columns; ++k) - { - const Type bkj = other.at(k, j); - for (std::size_t i = 0; i < Rows; ++i) - result.at(i, j) += at(i, k) * bkj; - } - else - std::unreachable(); - - return result; + return cache_friendly_multiply_row_major(other); + if constexpr (StoreType == MatStoreType::COLUMN_MAJOR) + return cache_friendly_multiply_col_major(other); + std::unreachable(); } constexpr Mat& operator*=(const Type& f) noexcept @@ -378,6 +361,36 @@ namespace omath private: std::array m_data; + + template [[nodiscard]] + constexpr Mat + cache_friendly_multiply_row_major(const Mat& other) const + { + Mat result; + for (std::size_t i = 0; i < Rows; ++i) + for (std::size_t k = 0; k < Columns; ++k) + { + const Type aik = at(i, k); + for (std::size_t j = 0; j < OtherColumns; ++j) + result.at(i, j) += aik * other.at(k, j); + } + return result; + } + + template [[nodiscard]] + constexpr Mat cache_friendly_multiply_col_major( + const Mat& other) const + { + Mat result; + for (std::size_t j = 0; j < OtherColumns; ++j) + for (std::size_t k = 0; k < Columns; ++k) + { + const Type bkj = other.at(k, j); + for (std::size_t i = 0; i < Rows; ++i) + result.at(i, j) += at(i, k) * bkj; + } + return result; + } }; template [[nodiscard]]