Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kronecker Product addition to stdlib_linalg #700

Merged
merged 12 commits into from
Mar 8, 2023
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(fppFiles
stdlib_linalg.fypp
stdlib_linalg_diag.fypp
stdlib_linalg_outer_product.fypp
stdlib_linalg_kronecker.fypp
stdlib_linalg_cross_product.fypp
stdlib_optval.fypp
stdlib_selection.fypp
Expand Down
15 changes: 15 additions & 0 deletions src/stdlib_linalg.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module stdlib_linalg
public :: eye
public :: trace
public :: outer_product
public :: kronecker_product
public :: cross_product
public :: is_square
public :: is_diagonal
Expand Down Expand Up @@ -93,6 +94,20 @@ module stdlib_linalg
#:endfor
end interface outer_product

interface kronecker_product
!! version: experimental
!!
!! Computes the Kronecker product of two arrays size M1xN1, M2xN2, returning an (M1*M2)x(N1*N2) array
adenchfi marked this conversation as resolved.
Show resolved Hide resolved
!! ([Specification](../page/specs/stdlib_linalg.html#
!! kronecker_product-computes-the-kronecker-product-of-two-matrices))
#:for k1, t1 in RCI_KINDS_TYPES
pure module function kronecker_product_${t1[0]}$${k1}$(A, B) result(C)
${t1}$, intent(in) :: A(:,:), B(:,:)
${t1}$ :: C(size(A,dim=1)*size(B,dim=1),size(A,dim=2)*size(B,dim=2))
end function kronecker_product_${t1[0]}$${k1}$
#:endfor
end interface kronecker_product


! Cross product (of two vectors)
interface cross_product
Expand Down
28 changes: 28 additions & 0 deletions src/stdlib_linalg_kronecker.fypp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#:include "common.fypp"
#:set RCI_KINDS_TYPES = REAL_KINDS_TYPES + CMPLX_KINDS_TYPES + INT_KINDS_TYPES
submodule (stdlib_linalg) stdlib_linalg_kronecker

implicit none

contains

#:for k1, t1 in RCI_KINDS_TYPES
pure module function kronecker_product_${t1[0]}$${k1}$(A, B) result(C)
${t1}$, intent(in) :: A(:,:), B(:,:)
${t1}$ :: C(size(A,dim=1)*size(B,dim=1),size(A,dim=2)*size(B,dim=2))
integer :: m1, n1, m2, n2, maxM1, maxN1, maxM2, maxN2
jvdp1 marked this conversation as resolved.
Show resolved Hide resolved

maxM1 = size(A, dim=1)
maxN1 = size(A, dim=2)
maxM2 = size(B, dim=1)
maxN2 = size(B, dim=2)

do n1=1, maxN1
do m1=1, maxM1
adenchfi marked this conversation as resolved.
Show resolved Hide resolved
! We use the numpy.kron convention for ordering of the matrix elements
jvdp1 marked this conversation as resolved.
Show resolved Hide resolved
C((m1-1)*maxM2+1:m1*maxM2, (n1-1)*maxN2+1:n1*maxN2) = A(m1, n1) * B(:,:)
end do
end do
end function kronecker_product_${t1[0]}$${k1}$
#:endfor
end submodule stdlib_linalg_kronecker