diff --git a/tasks/mpi/example/include/ops_mpi.hpp b/tasks/mpi/example/include/ops_mpi.hpp index afab7aae8..700f1be02 100644 --- a/tasks/mpi/example/include/ops_mpi.hpp +++ b/tasks/mpi/example/include/ops_mpi.hpp @@ -18,6 +18,10 @@ class TestTaskMPI : public ppc::core::Task { private: std::vector input_, output_; int rc_size_{}; + + void MultiplyMatrixBasedOnRank(); + void MultiplyRowMajor(); + void MultiplyColumnMajor(); }; } // namespace nesterov_a_test_task_mpi diff --git a/tasks/mpi/example/src/ops_mpi.cpp b/tasks/mpi/example/src/ops_mpi.cpp index 3c5af2805..331d278df 100644 --- a/tasks/mpi/example/src/ops_mpi.cpp +++ b/tasks/mpi/example/src/ops_mpi.cpp @@ -25,26 +25,40 @@ bool nesterov_a_test_task_mpi::TestTaskMPI::ValidationImpl() { } bool nesterov_a_test_task_mpi::TestTaskMPI::RunImpl() { + MultiplyMatrixBasedOnRank(); + return true; +} + +void nesterov_a_test_task_mpi::TestTaskMPI::MultiplyMatrixBasedOnRank() { int rank = -1; MPI_Comm_rank(MPI_COMM_WORLD, &rank); - auto multiply = [this](bool row_major) { - for (int i = 0; i < rc_size_; ++i) { - for (int j = 0; j < rc_size_; ++j) { - int sum = 0; - for (int k = 0; k < rc_size_; ++k) { - int a = input_[(row_major ? i : k) * rc_size_ + (row_major ? k : i)]; - int b = input_[(row_major ? k : j) * rc_size_ + (row_major ? j : k)]; - sum += a * b; - } - output_[(i * rc_size_) + j] += sum; + if (rank == 0) { + MultiplyRowMajor(); + } else { + MultiplyColumnMajor(); + } + MPI_Barrier(MPI_COMM_WORLD); +} + +void nesterov_a_test_task_mpi::TestTaskMPI::MultiplyRowMajor() { + for (int i = 0; i < rc_size_; ++i) { + for (int j = 0; j < rc_size_; ++j) { + for (int k = 0; k < rc_size_; ++k) { + output_[(i * rc_size_) + j] += input_[(i * rc_size_) + k] * input_[(k * rc_size_) + j]; } } - }; + } +} - multiply(rank == 0); - MPI_Barrier(MPI_COMM_WORLD); - return true; +void nesterov_a_test_task_mpi::TestTaskMPI::MultiplyColumnMajor() { + for (int j = 0; j < rc_size_; ++j) { + for (int k = 0; k < rc_size_; ++k) { + for (int i = 0; i < rc_size_; ++i) { + output_[(i * rc_size_) + j] += input_[(i * rc_size_) + k] * input_[(k * rc_size_) + j]; + } + } + } } bool nesterov_a_test_task_mpi::TestTaskMPI::PostProcessingImpl() {