Skip to content

Commit

Permalink
Add __CHECK_DIAG flag to validate diagonalisation
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrack committed Feb 7, 2019
1 parent 75b56d4 commit 1488e6b
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 8 deletions.
89 changes: 85 additions & 4 deletions src/fm/cp_fm_diag.F
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,22 @@ SUBROUTINE choose_eigv_solver(matrix, eigenvectors, eigenvalues, info)
INTEGER, INTENT(OUT), OPTIONAL :: info

CHARACTER(len=*), PARAMETER :: routineN = 'choose_eigv_solver', &
routineP = moduleN//':'//routineN
routineP = moduleN//':'//routineN

INTEGER :: myinfo, nmo
#if defined(__CHECK_DIAG)
CHARACTER(LEN=5), DIMENSION(3), PARAMETER :: diag_driver = (/"SYEVD", &
"SYEVR", &
"ELPA "/)
REAL(KIND=dp), PARAMETER :: eps = 1.0E-12_dp
INTEGER :: i, j, n
#if (defined(__SCALAPACK) || defined(__SCALAPACK2))
TYPE(cp_blacs_env_type), POINTER :: context
INTEGER :: il, jl, ipcol, iprow, &
mypcol, myprow, npcol, nprow
INTEGER, DIMENSION(9) :: desca
#endif
#endif

myinfo = 0

Expand All @@ -186,13 +199,81 @@ SUBROUTINE choose_eigv_solver(matrix, eigenvectors, eigenvalues, info)
ELSE IF (diag_type == 2) THEN
CALL cp_fm_syevr(matrix, eigenvectors, eigenvalues, 1, nmo)
ELSE IF (diag_type == 1) THEN
CALL cp_fm_syevd(matrix, eigenvectors, eigenvalues, info=myinfo)
IF (PRESENT(info)) THEN
CALL cp_fm_syevd(matrix, eigenvectors, eigenvalues, info=myinfo)
ELSE
CALL cp_fm_syevd(matrix, eigenvectors, eigenvalues)
END IF
ELSE
CPABORT("Unknown DIAG type")
END IF

IF (PRESENT(info)) info = myinfo

#if defined(__CHECK_DIAG)
#if (defined(__SCALAPACK) || defined(__SCALAPACK2))
n = eigenvectors%matrix_struct%nrow_global
CALL cp_fm_gemm("T", "N", nmo, nmo, n, 1.0_dp, eigenvectors, eigenvectors, 0.0_dp, matrix)
context => matrix%matrix_struct%context
myprow = context%mepos(1)
mypcol = context%mepos(2)
nprow = context%num_pe(1)
npcol = context%num_pe(2)
desca(:) = matrix%matrix_struct%descriptor(:)
DO i = 1, nmo
DO j = 1, nmo
CALL infog2l(i, j, desca, nprow, npcol, myprow, mypcol, il, jl, iprow, ipcol)
IF ((iprow == myprow) .AND. (ipcol == mypcol)) THEN
IF (i == j) THEN
IF (ABS(matrix%local_data(il, jl)-1.0_dp) > eps) THEN
WRITE (UNIT=*, FMT="(/,T2,A,/,T2,A,I0,A,I0,A,F0.12,/,T2,A)") &
"The eigenvectors returned by "//TRIM(diag_driver(diag_type))//" are not orthonormal", &
"Matrix element (", i, ", ", j, ") = ", matrix%local_data(il, jl), &
"The expected value is 1"
CPABORT("Matrix diagonalization failed")
END IF
ELSE
IF (ABS(matrix%local_data(il, jl)) > eps) THEN
WRITE (UNIT=*, FMT="(/,T2,A,/,T2,A,I0,A,I0,A,F0.12,/,T2,A)") &
"The eigenvectors returned by "//TRIM(diag_driver(diag_type))//" are not orthonormal", &
"Matrix element (", i, ", ", j, ") = ", matrix%local_data(il, jl), &
"The expected value is 0"
CPABORT("Matrix diagonalization failed")
END IF
END IF
END IF
END DO
END DO
#else
n = SIZE(eigenvectors%local_data, 1)
CALL dgemm("T", "N", nmo, nmo, n, 1.0_dp, &
eigenvectors%local_data(1, 1), n, &
eigenvectors%local_data(1, 1), n, &
0.0_dp, matrix%local_data(1, 1), n)
DO i = 1, nmo
DO j = 1, nmo
IF (i == j) THEN
IF (ABS(matrix%local_data(i, j)-1.0_dp) > eps) THEN
WRITE (UNIT=*, FMT="(/,T2,A,/,T2,A,I0,A,I0,A,F0.12,/,T2,A)") &
"The eigenvectors returned by "//TRIM(diag_driver(diag_type))//" are not orthonormal", &
"Matrix element (", i, ", ", j, ") = ", matrix%local_data(i, j), &
"The expected value is 1"
CPABORT("Matrix diagonalization failed")
END IF
ELSE
IF (ABS(matrix%local_data(i, j)) > eps) THEN
WRITE (UNIT=*, FMT="(/,T2,A,/,T2,A,I0,A,I0,A,F0.12,/,T2,A)") &
"The eigenvectors returned by "//TRIM(diag_driver(diag_type))//" are not orthonormal", &
"Matrix element (", i, ", ", j, ") = ", matrix%local_data(i, j), &
"The expected value is 0"
CPABORT("Matrix diagonalization failed")
END IF
END IF
END DO
END DO
#endif
#endif

END SUBROUTINE choose_eigv_solver

! **************************************************************************************************
Expand Down Expand Up @@ -252,7 +333,6 @@ SUBROUTINE cp_fm_syevd(matrix, eigenvectors, eigenvalues, info)
#else

!MK Retrieve the optimal work array sizes first
myinfo = 0
lwork = -1
liwork = -1
m => matrix%local_data
Expand All @@ -273,6 +353,7 @@ SUBROUTINE cp_fm_syevd(matrix, eigenvectors, eigenvalues, info)
lwork = INT(work(1))
DEALLOCATE (work)
ALLOCATE (work(lwork))
work(:) = 0.0_dp

liwork = iwork(1)
DEALLOCATE (iwork)
Expand All @@ -291,7 +372,7 @@ SUBROUTINE cp_fm_syevd(matrix, eigenvectors, eigenvalues, info)
DEALLOCATE (work)
#endif

IF (PRESENT(info)) myinfo = 0
IF (PRESENT(info)) info = myinfo

nmo = SIZE(eigenvalues, 1)
IF (nmo > n) THEN
Expand Down
17 changes: 13 additions & 4 deletions src/fm/cp_fm_types.F
Original file line number Diff line number Diff line change
Expand Up @@ -2148,15 +2148,17 @@ END SUBROUTINE cp_fm_write_unformatted
!> \param fm the matrix to be outputted
!> \param unit the unit number for I/O
!> \param header optional header
!> \param value_format ...
! **************************************************************************************************
SUBROUTINE cp_fm_write_formatted(fm, unit, header)
SUBROUTINE cp_fm_write_formatted(fm, unit, header, value_format)
TYPE(cp_fm_type), POINTER :: fm
INTEGER :: unit
CHARACTER(LEN=*), OPTIONAL :: header
CHARACTER(LEN=*), OPTIONAL :: header, value_format
CHARACTER(len=*), PARAMETER :: routineN = 'cp_fm_write_formatted', &
routineP = moduleN//':'//routineN
CHARACTER(LEN=21) :: my_value_format
INTEGER :: handle, i, j, max_block, &
ncol_global, nrow_global, &
nrow_block
Expand All @@ -2178,6 +2180,13 @@ SUBROUTINE cp_fm_write_formatted(fm, unit, header)
CALL cp_fm_get_info(fm, nrow_global=nrow_global, ncol_global=ncol_global, ncol_block=max_block, &
nrow_block=nrow_block, para_env=para_env)
IF (PRESENT(value_format)) THEN
CPASSERT(LEN_TRIM(ADJUSTL(value_format)) < 11)
my_value_format = "(I10, I10, "//TRIM(ADJUSTL(value_format))//")"
ELSE
my_value_format = "(I10, I10, ES24.12)"
END IF
IF (unit > 0) THEN
IF (PRESENT(header)) WRITE (unit, *) header
WRITE (unit, "(A2, A8, A10, A24)") "#", "Row", "Column", ADJUSTL("Value")
Expand Down Expand Up @@ -2232,7 +2241,7 @@ SUBROUTINE cp_fm_write_formatted(fm, unit, header)
IF (unit > 0) THEN
DO j = 1, i_block
DO k = (j-1)*nrow_global+1, nrow_global*j
WRITE (unit, "(I10, I10, ES24.12)") irow, icol, vecbuf(k)
WRITE (UNIT=unit, FMT=my_value_format) irow, icol, vecbuf(k)
irow = irow+1
IF (irow > nrow_global) THEN
irow = 1
Expand All @@ -2254,7 +2263,7 @@ SUBROUTINE cp_fm_write_formatted(fm, unit, header)
IF (unit > 0) THEN
DO j = 1, ncol_global
DO i = 1, nrow_global
WRITE (unit, "(I10, I10, ES24.12)") i, j, fm%local_data(i, j)
WRITE (UNIT=unit, FMT=my_value_format) i, j, fm%local_data(i, j)
END DO
END DO
END IF
Expand Down

0 comments on commit 1488e6b

Please sign in to comment.