@@ -30,13 +30,13 @@ contains
3030    module subroutine stdlib_linalg_${ri}$_expm(A, E, order, err)
3131        !> Input matrix A(n, n).
3232        ${rt}$, intent(in) :: A(:, :)
33-         !> [optional] Order of the Pade approximation.
33+         !> Exponential of the input matrix E = exp(A).
34+         ${rt}$, intent(out) :: E(:, :)
35+          !> [optional] Order of the Pade approximation.
3436        integer(ilp), optional, intent(in) :: order
3537        !> [optional] State return flag.
3638        type(linalg_state_type), optional, intent(out) :: err
37-         !> Exponential of the input matrix E = exp(A).
38-         ${rt}$, intent(out) :: E(:, :)
39-         
39+        
4040        type(linalg_state_type) :: err0
4141        integer(ilp) :: lda, n, lde, ne
4242
@@ -68,7 +68,7 @@ contains
6868        type(linalg_state_type), optional, intent(out) :: err
6969
7070        ! Internal variables.
71-         ${rt}$, allocatable     :: A2(:, :), Q(:, :), X(:, :)
71+         ${rt}$, allocatable     :: A2(:, :), Q(:, :), X(:, :), X_tmp(:, :) 
7272        real(${rk}$)            :: a_norm, c
7373        integer(ilp)            :: m, n, ee, k, s, order_, i, j
7474        logical(lk)             :: p
@@ -105,32 +105,29 @@ contains
105105            enddo
106106
107107            ! Iteratively compute the Pade approximation.
108-             block
109-                 ${rt}$, allocatable :: X_tmp(:, :)
110-                 p = .true.
111-                 do k = 2, order_
112-                     c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
113-                     X_tmp = X
114-                     #:if rt.startswith('complex')
115-                     call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
116-                     #:else
117-                     call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
118-                     #:endif
108+             p = .true.
109+             do k = 2, order_
110+                 c = c * (order_ - k + 1) / (k * (2*order_ - k + 1))
111+                 X_tmp = X
112+                 #:if rt.startswith('complex')
113+                 call gemm("N", "N", n, n, n, one_c${rk}$, A2, n, X_tmp, n, zero_c${rk}$, X, n)
114+                 #:else
115+                 call gemm("N", "N", n, n, n, one_${rk}$, A2, n, X_tmp, n, zero_${rk}$, X, n)
116+                 #:endif
117+                 do concurrent(i=1:n, j=1:n)
118+                     A(i, j) = A(i, j) + c*X(i, j)       ! E = E + c*X
119+                 enddo
120+                 if (p) then
119121                    do concurrent(i=1:n, j=1:n)
120-                         A (i, j) = A (i, j) + c*X(i, j)       ! E  = E  + c*X
122+                         Q (i, j) = Q (i, j) + c*X(i, j)   ! Q  = Q  + c*X
121123                    enddo
122-                     if (p) then
123-                         do concurrent(i=1:n, j=1:n)
124-                             Q(i, j) = Q(i, j) + c*X(i, j)   ! Q = Q + c*X
125-                         enddo
126-                     else
127-                         do concurrent(i=1:n, j=1:n)
128-                             Q(i, j) = Q(i, j) - c*X(i, j)   ! Q = Q - c*X
129-                         enddo
130-                     endif
131-                     p = .not. p
132-                 enddo
133-             end block
124+                 else
125+                     do concurrent(i=1:n, j=1:n)
126+                         Q(i, j) = Q(i, j) - c*X(i, j)   ! Q = Q - c*X
127+                     enddo
128+                 endif
129+                 p = .not. p
130+             enddo
134131
135132            block
136133                integer(ilp) :: ipiv(n), info
@@ -139,17 +136,14 @@ contains
139136            end block
140137
141138            ! Matrix squaring.
142-             block
143-                 ${rt}$, allocatable :: E_tmp(:, :)
144-                 do k = 1, s
145-                     E_tmp = A
146-                     #:if rt.startswith('complex')
147-                     call gemm("N", "N", n, n, n, one_c${rk}$, E_tmp, n, E_tmp, n, zero_c${rk}$, A, n)
148-                     #:else
149-                     call gemm("N", "N", n, n, n, one_${rk}$, E_tmp, n, E_tmp, n, zero_${rk}$, A, n)
150-                     #:endif
151-                 enddo
152-             end block
139+             do k = 1, s
140+                 X = A ! Re-use X to minimize allocations.
141+                 #:if rt.startswith('complex')
142+                 call gemm("N", "N", n, n, n, one_c${rk}$, X, n, X, n, zero_c${rk}$, A, n)
143+                 #:else
144+                 call gemm("N", "N", n, n, n, one_${rk}$, X, n, X, n, zero_${rk}$, A, n)
145+                 #:endif
146+             enddo
153147        endif
154148
155149        call linalg_error_handling(err0, err)
0 commit comments