Skip to content

Commit 3192a16

Browse files
committed
Removed one block and one temporary variable allocation.
1 parent 5eb5273 commit 3192a16

File tree

2 files changed

+36
-42
lines changed

2 files changed

+36
-42
lines changed

src/stdlib_linalg.fypp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ module stdlib_linalg
16931693
!! `real` and `complex`.
16941694
!!
16951695
!! By default, the order of the Pade approximation is set to 10. It can be changed
1696-
!! via the `order` argument which must be non-negative.
1696+
!! via the `order` argument that must be non-negative.
16971697
!!
16981698
!! If the input matrix is non-square or the order of the Pade approximation is
16991699
!! negative, the function returns an error state.
@@ -1738,7 +1738,7 @@ module stdlib_linalg
17381738
!! `real` and `complex`.
17391739
!!
17401740
!! By default, the order of the Pade approximation is set to 10. It can be changed
1741-
!! via the `order` argument which must be non-negative.
1741+
!! via the `order` argument that must be non-negative.
17421742
!!
17431743
!! If the input matrix is non-square or the order of the Pade approximation is
17441744
!! negative, the function returns an error state.

src/stdlib_linalg_matrix_functions.fypp

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ contains
3030
module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
3131
!> Input matrix A(n, n).
3232
${rt}$, intent(in) :: A(:, :)
33-
!> [optional] Order of the Pade approximation.
33+
!> Exponential of the input matrix E = exp(A).
34+
${rt}$, intent(out) :: E(:, :)
35+
!> [optional] Order of the Pade approximation.
3436
integer(ilp), optional, intent(in) :: order
3537
!> [optional] State return flag.
3638
type(linalg_state_type), optional, intent(out) :: err
37-
!> Exponential of the input matrix E = exp(A).
38-
${rt}$, intent(out) :: E(:, :)
39-
39+
4040
type(linalg_state_type) :: err0
4141
integer(ilp) :: lda, n, lde, ne
4242

@@ -68,7 +68,7 @@ contains
6868
type(linalg_state_type), optional, intent(out) :: err
6969

7070
! Internal variables.
71-
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :)
71+
${rt}$, allocatable :: A2(:, :), Q(:, :), X(:, :), X_tmp(:, :)
7272
real(${rk}$) :: a_norm, c
7373
integer(ilp) :: m, n, ee, k, s, order_, i, j
7474
logical(lk) :: p
@@ -105,32 +105,29 @@ contains
105105
enddo
106106

107107
! Iteratively compute the Pade approximation.
108-
block
109-
${rt}$, allocatable :: X_tmp(:, :)
110-
p = .true.
111-
do k = 2, order_
112-
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
113-
X_tmp = X
114-
#:if rt.startswith('complex')
115-
call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
116-
#:else
117-
call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
118-
#:endif
108+
p = .true.
109+
do k = 2, order_
110+
c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111+
X_tmp = X
112+
#:if rt.startswith('complex')
113+
call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114+
#:else
115+
call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116+
#:endif
117+
do concurrent(i=1:n, j=1:n)
118+
A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X
119+
enddo
120+
if (p) then
119121
do concurrent(i=1:n, j=1:n)
120-
A(i, j) = A(i, j) + c*X(i, j) ! E = E + c*X
122+
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
121123
enddo
122-
if (p) then
123-
do concurrent(i=1:n, j=1:n)
124-
Q(i, j) = Q(i, j) + c*X(i, j) ! Q = Q + c*X
125-
enddo
126-
else
127-
do concurrent(i=1:n, j=1:n)
128-
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
129-
enddo
130-
endif
131-
p = .not. p
132-
enddo
133-
end block
124+
else
125+
do concurrent(i=1:n, j=1:n)
126+
Q(i, j) = Q(i, j) - c*X(i, j) ! Q = Q - c*X
127+
enddo
128+
endif
129+
p = .not. p
130+
enddo
134131

135132
block
136133
integer(ilp) :: ipiv(n), info
@@ -139,17 +136,14 @@ contains
139136
end block
140137

141138
! Matrix squaring.
142-
block
143-
${rt}$, allocatable :: E_tmp(:, :)
144-
do k = 1, s
145-
E_tmp = A
146-
#:if rt.startswith('complex')
147-
call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n)
148-
#:else
149-
call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n)
150-
#:endif
151-
enddo
152-
end block
139+
do k = 1, s
140+
X = A ! Re-use X to minimize allocations.
141+
#:if rt.startswith('complex')
142+
call gemm("N", "N", n, n, n, one_c${rk}$, X, n, X, n, zero_c${rk}$, A, n)
143+
#:else
144+
call gemm("N", "N", n, n, n, one_${rk}$, X, n, X, n, zero_${rk}$, A, n)
145+
#:endif
146+
enddo
153147
endif
154148

155149
call linalg_error_handling(err0, err)

0 commit comments

Comments
 (0)