Skip to content

Commit

Permalink
Feature: output times for lambda inner loop (#3982)
Browse files Browse the repository at this point in the history
  • Loading branch information
hongriTianqi committed Apr 15, 2024
1 parent 43cde6d commit 9e6f12f
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 12 deletions.
34 changes: 33 additions & 1 deletion source/module_hamilt_lcao/module_deltaspin/lambda_loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <iostream>
#include <cmath>
#include <chrono>

#include "basic_funcs.h"

Expand Down Expand Up @@ -31,10 +32,19 @@ void SpinConstrain<std::complex<double>, psi::DEVICE_CPU>::run_lambda_loop(int o
const double zero = 0.0;
const double one = 1.0;

#ifdef __MPI
auto iterstart = MPI_Wtime();
#else
auto iterstart = std::chrono::system_clock::now();
#endif

double inner_loop_duration = 0.0;

this->print_header();
// lambda loop
for (int i_step = 0; i_step < this->nsc_; i_step++)
{
double duration = 0.0;
if (i_step == 0)
{
spin = this->Mi_;
Expand All @@ -53,6 +63,15 @@ void SpinConstrain<std::complex<double>, psi::DEVICE_CPU>::run_lambda_loop(int o
if (i_step >= this->nsc_min_ && GradLessThanBound)
{
add_scalar_multiply_2d(initial_lambda, dnu_last_step, one, this->lambda_);
#ifdef __MPI
duration = (double)(MPI_Wtime() - iterstart);
#else
duration =
(std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now()
- iterstart)).count() / static_cast<double>(1e6);
#endif
inner_loop_duration += duration;
std::cout << "Total TIME(s) = " << inner_loop_duration << std::endl;
this->print_termination();
break;
}
Expand All @@ -71,11 +90,24 @@ void SpinConstrain<std::complex<double>, psi::DEVICE_CPU>::run_lambda_loop(int o
}
mean_error = sum_2d(temp_1) / nat;
rms_error = std::sqrt(mean_error);
if (this->check_rms_stop(outer_step, i_step, rms_error))
#ifdef __MPI
duration = (double)(MPI_Wtime() - iterstart);
#else
duration =
(std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now()
- iterstart)).count() / static_cast<double>(1e6);
#endif
inner_loop_duration += duration;
if (this->check_rms_stop(outer_step, i_step, rms_error, duration, inner_loop_duration))
{
add_scalar_multiply_2d(initial_lambda, dnu_last_step, 1.0, this->lambda_);
break;
}
#ifdef __MPI
iterstart = MPI_Wtime();
#else
iterstart = std::chrono::system_clock::now();
#endif
if (i_step >= 2)
{
beta = mean_error / mean_error_old;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@ void SpinConstrain<std::complex<double>, psi::DEVICE_CPU>::print_termination()
}

template <>
bool SpinConstrain<std::complex<double>, psi::DEVICE_CPU>::check_rms_stop(int outer_step, int i_step, double rms_error)
bool SpinConstrain<std::complex<double>, psi::DEVICE_CPU>::check_rms_stop(int outer_step, int i_step, double rms_error, double duration, double total_duration)
{
std::cout << "Step (Outer -- Inner) = " << outer_step << " -- " << std::left << std::setw(5) << i_step + 1
<< " RMS = " << rms_error << std::endl;
<< " RMS = " << rms_error << " TIME(s) = " << std::setw(11) << duration << std::endl;
if (rms_error < this->sc_thr_ || i_step == this->nsc_ - 1)
{
if (rms_error < this->sc_thr_)
{
std::cout << "Meet convergence criterion ( < " << this->sc_thr_ << " ), exit." << std::endl;
std::cout << "Meet convergence criterion ( < " << this->sc_thr_ << " ), exit.";
std::cout << " Total TIME(s) = " << total_duration << std::endl;
}
else if (i_step == this->nsc_ - 1)
{
std::cout << "Reach maximum number of steps ( " << this->nsc_ << " ), exit." << std::endl;
std::cout << "Reach maximum number of steps ( " << this->nsc_ << " ), exit.";
std::cout << " Total TIME(s) = " << total_duration << std::endl;
}
this->print_termination();
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class SpinConstrain
void run_lambda_loop(int outer_step);

/// lambda loop helper functions
bool check_rms_stop(int outer_step, int i_step, double rms_error);
bool check_rms_stop(int outer_step, int i_step, double rms_error, double duration, double total_duration);

/// apply restriction
void check_restriction(const std::vector<ModuleBase::Vector3<double>>& search, double& alpha_trial);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void SpinConstrain<double, psi::DEVICE_CPU>::run_lambda_loop(int outer_step)
}

template <>
bool SpinConstrain<double, psi::DEVICE_CPU>::check_rms_stop(int outer_step, int i_step, double rms_error)
bool SpinConstrain<double, psi::DEVICE_CPU>::check_rms_stop(int outer_step, int i_step, double rms_error, double duration, double total_duration)
{
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ TEST_F(SpinConstrainTest, CheckRmsStop)
double alpha_trial = 0.01;
double sccut = 3.0;
bool decay_grad_switch = 1;
double duration = 10;
double total_duration = 10;
this->sc.set_input_parameters(sc_thr, nsc, nsc_min, alpha_trial, sccut, decay_grad_switch);
testing::internal::CaptureStdout();
EXPECT_FALSE(sc.check_rms_stop(0, 0, 1e-5));
EXPECT_FALSE(sc.check_rms_stop(0, 11, 1e-5));
EXPECT_TRUE(sc.check_rms_stop(0, 12, 1e-7));
EXPECT_TRUE(sc.check_rms_stop(0, 99, 1e-5));
EXPECT_FALSE(sc.check_rms_stop(0, 0, 1e-5, duration, total_duration));
EXPECT_FALSE(sc.check_rms_stop(0, 11, 1e-5, duration, total_duration));
EXPECT_TRUE(sc.check_rms_stop(0, 12, 1e-7, duration, total_duration));
EXPECT_TRUE(sc.check_rms_stop(0, 99, 1e-5, duration, total_duration));
std::string output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Step (Outer -- Inner) = 0 -- 1 RMS = 1e-05"));
EXPECT_THAT(output, testing::HasSubstr("Step (Outer -- Inner) = 0 -- 12 RMS = 1e-05"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST_F(SpinConstrainTest, TemplatHelpers)
ModuleBase::ComplexMatrix mud;
ModuleBase::matrix MecMulP;
EXPECT_NO_THROW(sc.collect_MW(MecMulP, mud, 0, 0));
EXPECT_FALSE(sc.check_rms_stop(0, 0, 0.0));
EXPECT_FALSE(sc.check_rms_stop(0, 0, 0.0, 0.0, 0.0));
EXPECT_NO_THROW(sc.print_termination());
EXPECT_NO_THROW(sc.print_header());
std::vector<ModuleBase::Vector3<double>> new_spin, old_spin, new_delta_lambda, old_delta_lambda;
Expand Down

0 comments on commit 9e6f12f

Please sign in to comment.