|  | 
|  | 1 | +#:include "common.fypp" | 
|  | 2 | +#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX, REAL_INIT)) | 
|  | 3 | +#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX, CMPLX_INIT)) | 
|  | 4 | +#:set RC_KINDS_TYPES = R_KINDS_TYPES + C_KINDS_TYPES | 
|  | 5 | +submodule (stdlib_linalg) stdlib_linalg_matrix_functions | 
|  | 6 | +    use stdlib_constants | 
|  | 7 | +    use stdlib_linalg_constants | 
|  | 8 | +    use stdlib_linalg_blas, only: gemm | 
|  | 9 | +    use stdlib_linalg_lapack, only: gesv, lacpy | 
|  | 10 | +    use stdlib_linalg_lapack_aux, only: handle_gesv_info | 
|  | 11 | +    use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_ERROR, & | 
|  | 12 | +         LINALG_INTERNAL_ERROR, LINALG_VALUE_ERROR | 
|  | 13 | +    implicit none(type, external) | 
|  | 14 | + | 
|  | 15 | +    character(len=*), parameter :: this = "matrix_exponential" | 
|  | 16 | + | 
|  | 17 | +contains | 
|  | 18 | + | 
|  | 19 | +    #:for k,t,s, i in RC_KINDS_TYPES  | 
|  | 20 | +    module function stdlib_linalg_${i}$_expm_fun(A, order) result(E) | 
|  | 21 | +        !> Input matrix A(n, n). | 
|  | 22 | +        ${t}$, intent(in) :: A(:, :) | 
|  | 23 | +        !> [optional] Order of the Pade approximation. | 
|  | 24 | +        integer(ilp), optional, intent(in) :: order | 
|  | 25 | +        !> Exponential of the input matrix E = exp(A). | 
|  | 26 | +        ${t}$, allocatable :: E(:, :) | 
|  | 27 | + | 
|  | 28 | +        E = A | 
|  | 29 | +        call stdlib_linalg_${i}$_expm_inplace(E, order) | 
|  | 30 | +    end function stdlib_linalg_${i}$_expm_fun | 
|  | 31 | + | 
|  | 32 | +    module subroutine stdlib_linalg_${i}$_expm(A, E, order, err) | 
|  | 33 | +        !> Input matrix A(n, n). | 
|  | 34 | +        ${t}$, intent(in) :: A(:, :) | 
|  | 35 | +        !> Exponential of the input matrix E = exp(A). | 
|  | 36 | +        ${t}$, intent(out) :: E(:, :) | 
|  | 37 | +         !> [optional] Order of the Pade approximation. | 
|  | 38 | +        integer(ilp), optional, intent(in) :: order | 
|  | 39 | +        !> [optional] State return flag. | 
|  | 40 | +        type(linalg_state_type), optional, intent(out) :: err | 
|  | 41 | +        | 
|  | 42 | +        type(linalg_state_type) :: err0 | 
|  | 43 | +        integer(ilp) :: lda, n, lde, ne | 
|  | 44 | +          | 
|  | 45 | +        ! Check E sizes | 
|  | 46 | +        lda = size(A, 1, kind=ilp) ; n = size(A, 2, kind=ilp) | 
|  | 47 | +        lde = size(E, 1, kind=ilp) ; ne = size(E, 2, kind=ilp) | 
|  | 48 | +           | 
|  | 49 | +        if (lda<1 .or. n<1 .or. lda/=n .or. lde/=n .or. ne/=n) then      | 
|  | 50 | +            err0 = linalg_state_type(this,LINALG_VALUE_ERROR, & | 
|  | 51 | +                                     'invalid matrix sizes: A must be square (lda=', lda, ', n=', n, ')', & | 
|  | 52 | +                                     ' E must be square (lde=', lde, ', ne=', ne, ')') | 
|  | 53 | +        else | 
|  | 54 | +            call lacpy("n", n, n, A, n, E, n) ! E = A | 
|  | 55 | +            call stdlib_linalg_${i}$_expm_inplace(E, order, err0) | 
|  | 56 | +        endif | 
|  | 57 | +         | 
|  | 58 | +        ! Process output and return | 
|  | 59 | +        call linalg_error_handling(err0,err) | 
|  | 60 | + | 
|  | 61 | +        return | 
|  | 62 | +    end subroutine stdlib_linalg_${i}$_expm | 
|  | 63 | + | 
|  | 64 | +    module subroutine stdlib_linalg_${i}$_expm_inplace(A, order, err) | 
|  | 65 | +        !> Input matrix A(n, n) / Output matrix exponential. | 
|  | 66 | +        ${t}$, intent(inout) :: A(:, :) | 
|  | 67 | +        !> [optional] Order of the Pade approximation. | 
|  | 68 | +        integer(ilp), optional, intent(in) :: order | 
|  | 69 | +        !> [optional] State return flag. | 
|  | 70 | +        type(linalg_state_type), optional, intent(out) :: err | 
|  | 71 | + | 
|  | 72 | +        ! Internal variables. | 
|  | 73 | +        ${t}$                   :: A2(size(A, 1), size(A, 2)), Q(size(A, 1), size(A, 2)) | 
|  | 74 | +        ${t}$                   :: X(size(A, 1), size(A, 2)), X_tmp(size(A, 1), size(A, 2)) | 
|  | 75 | +        real(${k}$)             :: a_norm, c | 
|  | 76 | +        integer(ilp)            :: m, n, ee, k, s, order_, i, j | 
|  | 77 | +        logical(lk)             :: p | 
|  | 78 | +        type(linalg_state_type) :: err0 | 
|  | 79 | + | 
|  | 80 | +        ! Deal with optional args. | 
|  | 81 | +        order_ = 10 ; if (present(order)) order_ = order | 
|  | 82 | + | 
|  | 83 | +        ! Problem's dimension. | 
|  | 84 | +        m = size(A, dim=1, kind=ilp) ; n = size(A, dim=2, kind=ilp) | 
|  | 85 | + | 
|  | 86 | +        if (m /= n) then | 
|  | 87 | +            err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'Invalid matrix size A=',[m, n]) | 
|  | 88 | +        else if (order_ < 0) then | 
|  | 89 | +            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'Order of Pade approximation & | 
|  | 90 | +                                    needs to be positive, order=', order_) | 
|  | 91 | +        else | 
|  | 92 | +            ! Compute the L-infinity norm. | 
|  | 93 | +            a_norm = mnorm(A, "inf") | 
|  | 94 | + | 
|  | 95 | +            ! Determine scaling factor for the matrix. | 
|  | 96 | +            ee = int(log(a_norm) / log2_${k}$, kind=ilp) + 1 | 
|  | 97 | +            s  = max(0, ee+1) | 
|  | 98 | + | 
|  | 99 | +            ! Scale the input matrix & initialize polynomial. | 
|  | 100 | +            A2 = A/2.0_${k}$**s | 
|  | 101 | +            call lacpy("n", n, n, A2, n, X, n) ! X = A2 | 
|  | 102 | + | 
|  | 103 | +            ! First step of the Pade approximation. | 
|  | 104 | +            c = 0.5_${k}$ | 
|  | 105 | +            do concurrent(i=1:n, j=1:n) | 
|  | 106 | +                A(i, j) = merge(1.0_${k}$ + c*A2(i, j), c*A2(i, j), i == j) | 
|  | 107 | +                Q(i, j) = merge(1.0_${k}$ - c*A2(i, j), -c*A2(i, j), i == j) | 
|  | 108 | +            enddo | 
|  | 109 | + | 
|  | 110 | +            ! Iteratively compute the Pade approximation. | 
|  | 111 | +            p = .true. | 
|  | 112 | +            do k = 2, order_ | 
|  | 113 | +                c = c * (order_ - k + 1) / (k * (2*order_ - k + 1)) | 
|  | 114 | +                call lacpy("n", n, n, X, n, X_tmp, n) ! X_tmp = X | 
|  | 115 | +                call gemm("N", "N", n, n, n, one_${s}$, A2, n, X_tmp, n, zero_${s}$, X, n) | 
|  | 116 | +                do concurrent(i=1:n, j=1:n) | 
|  | 117 | +                    A(i, j) = A(i, j) + c*X(i, j)       ! E = E + c*X | 
|  | 118 | +                    Q(i, j) = merge(Q(i, j) + c*X(i, j), Q(i, j) - c*X(i, j), p) | 
|  | 119 | +                enddo | 
|  | 120 | +                p = .not. p | 
|  | 121 | +            enddo | 
|  | 122 | + | 
|  | 123 | +            block | 
|  | 124 | +                integer(ilp) :: ipiv(n), info | 
|  | 125 | +                call gesv(n, n, Q, n, ipiv, A, n, info) ! E = inv(Q) @ E | 
|  | 126 | +                call handle_gesv_info(this, info, n, n, n, err0) | 
|  | 127 | +            end block | 
|  | 128 | + | 
|  | 129 | +            ! Matrix squaring. | 
|  | 130 | +            do k = 1, s | 
|  | 131 | +                call lacpy("n", n, n, A, n, X, n) ! X = A | 
|  | 132 | +                call gemm("N", "N", n, n, n, one_${s}$, X, n, X, n, zero_${s}$, A, n) | 
|  | 133 | +            enddo | 
|  | 134 | +        endif | 
|  | 135 | +         | 
|  | 136 | +        call linalg_error_handling(err0, err) | 
|  | 137 | + | 
|  | 138 | +        return | 
|  | 139 | +    end subroutine stdlib_linalg_${i}$_expm_inplace | 
|  | 140 | +    #:endfor | 
|  | 141 | + | 
|  | 142 | +end submodule stdlib_linalg_matrix_functions | 
0 commit comments