From 109e3554cb11c0108aa7fd22292196335e25e5fa Mon Sep 17 00:00:00 2001 From: Kazuki Komatsu Date: Mon, 11 Feb 2019 11:51:44 +0900 Subject: [PATCH 1/4] Add complex version gelsd --- source/mir/lapack.d | 72 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/source/mir/lapack.d b/source/mir/lapack.d index c310bac..493b774 100644 --- a/source/mir/lapack.d +++ b/source/mir/lapack.d @@ -12,6 +12,7 @@ import mir.ndslice.slice; import mir.ndslice.topology: retro; import mir.ndslice.iterator; import mir.utility: min, max; +import mir.internal.utility : realType, isComplex; static import lapack; @@ -181,6 +182,7 @@ size_t gelsd_wq(T)( Slice!(T*, 2, Canonical) b, ref size_t liwork, ) + if(!isComplex!T) { assert(b.length!1 == a.length!1); @@ -203,10 +205,45 @@ size_t gelsd_wq(T)( return cast(size_t) work; } + +/// ditto +size_t gelsd_wq(T)( + Slice!(T*, 2, Canonical) a, + Slice!(T*, 2, Canonical) b, + ref size_t lrwork, + ref size_t liwork, + ) + if(isComplex!T) +{ + assert(b.length!1 == a.length!1); + + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint nrhs = cast(lapackint) b.length; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldb = cast(lapackint) b._stride.max(1); + realType!T rcond = void; + lapackint rank = void; + T work = void; + lapackint lwork = -1; + realType!T rwork = void; + lapackint iwork = void; + lapackint info = void; + + lapack.gelsd_(m, n, nrhs, a.iterator, lda, b.iterator, ldb, null, rcond, rank, &work, lwork, &rwork, &iwork, info); + + assert(info == 0); + lrwork = cast(size_t)rwork; + liwork = iwork; + return cast(size_t) work; +} + unittest { alias s = gelsd_wq!float; alias d = gelsd_wq!double; + alias c = gelsd_wq!cfloat; + alias z = gelsd_wq!cdouble; } /// @@ -219,6 +256,7 @@ size_t gelsd(T)( Slice!(T*) work, Slice!(lapackint*) iwork, ) + if(!isComplex!T) { assert(b.length!1 == a.length!1); assert(s.length == min(a.length!0, a.length!1)); @@ -239,10 +277,44 @@ size_t gelsd(T)( return info; } +/// ditto +size_t gelsd(T)( + Slice!(T*, 2, Canonical) a, + Slice!(T*, 2, Canonical) b, + Slice!(realType!T*) s, + realType!T rcond, + ref size_t rank, + Slice!(T*) work, + Slice!(realType!T*) rwork, + Slice!(lapackint*) iwork, + ) + if(isComplex!T) +{ + assert(b.length!1 == a.length!1); + assert(s.length == min(a.length!0, a.length!1)); + + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint nrhs = cast(lapackint) b.length; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldb = cast(lapackint) b._stride.max(1); + lapackint rank_ = void; + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + + lapack.gelsd_(m, n, nrhs, a.iterator, lda, b.iterator, ldb, s.iterator, rcond, rank_, work.iterator, lwork, rwork.iterator, iwork.iterator, info); + + assert(info >= 0); + rank = rank_; + return info; +} + unittest { alias s = gelsd!float; alias d = gelsd!double; + alias c = gelsd!cfloat; + alias z = gelsd!cdouble; } /// `gesdd` work space query From 3298588f69a1655742e599424de8d64fa11bbadc Mon Sep 17 00:00:00 2001 From: Kazuki Komatsu Date: Mon, 11 Feb 2019 12:03:29 +0900 Subject: [PATCH 2/4] Add complex version gesdd --- source/mir/lapack.d | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/source/mir/lapack.d b/source/mir/lapack.d index 493b774..a416c7f 100644 --- a/source/mir/lapack.d +++ b/source/mir/lapack.d @@ -334,7 +334,14 @@ size_t gesdd_wq(T)( lapackint lwork = -1; lapackint info = void; - lapack.gesdd_(jobz, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, info); + static if(!isComplex!T) + { + lapack.gesdd_(jobz, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, info); + } + else + { + lapack.gesdd_(jobz, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, null, info); + } assert(info == 0); return cast(size_t) work; @@ -344,6 +351,8 @@ unittest { alias s = gesdd_wq!float; alias d = gesdd_wq!double; + alias c = gesdd_wq!cfloat; + alias z = gesdd_wq!cdouble; } /// @@ -356,6 +365,7 @@ size_t gesdd(T)( Slice!(T*) work, Slice!(lapackint*) iwork, ) + if(!isComplex!T) { lapackint m = cast(lapackint) a.length!1; lapackint n = cast(lapackint) a.length!0; @@ -371,10 +381,39 @@ size_t gesdd(T)( return info; } +/// ditto +size_t gesdd(T)( + char jobz, + Slice!(T*, 2, Canonical) a, + Slice!(realType!T*) s, + Slice!(T*, 2, Canonical) u, + Slice!(T*, 2, Canonical) vt, + Slice!(T*) work, + Slice!(realType!T*) rwork, + Slice!(lapackint*) iwork, + ) + if(isComplex!T) +{ + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldu = cast(lapackint) u._stride.max(1); + lapackint ldvt = cast(lapackint) vt._stride.max(1); + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + + lapack.gesdd_(jobz, m, n, a.iterator, lda, s.iterator, u.iterator, ldu, vt.iterator, ldvt, work.iterator, lwork, rwork.iterator, iwork.iterator, info); + + assert(info >= 0); + return info; +} + unittest { alias s = gesdd!float; alias d = gesdd!double; + alias c = gesdd!cfloat; + alias z = gesdd!cdouble; } /// `gesvd` work space query From 836cd44cb99e78ee82fcf2b3a1db8fcf0a5fc6bb Mon Sep 17 00:00:00 2001 From: Kazuki Komatsu Date: Mon, 11 Feb 2019 12:04:13 +0900 Subject: [PATCH 3/4] Add complex version gesvd --- source/mir/lapack.d | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/source/mir/lapack.d b/source/mir/lapack.d index a416c7f..eba96cc 100644 --- a/source/mir/lapack.d +++ b/source/mir/lapack.d @@ -434,7 +434,14 @@ size_t gesvd_wq(T)( lapackint lwork = -1; lapackint info = void; - lapack.gesvd_(jobu, jobvt, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, info); + static if(!isComplex!T) + { + lapack.gesvd_(jobu, jobvt, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, info); + } + else + { + lapack.gesvd_(jobu, jobvt, m, n, null, lda, null, null, ldu, null, ldvt, &work, lwork, null, info); + } assert(info == 0); return cast(size_t) work; @@ -444,6 +451,8 @@ unittest { alias s = gesvd_wq!float; alias d = gesvd_wq!double; + alias c = gesvd_wq!cfloat; + alias z = gesvd_wq!cdouble; } /// @@ -456,6 +465,7 @@ size_t gesvd(T)( Slice!(T*, 2, Canonical) vt, Slice!(T*) work, ) + if(!isComplex!T) { lapackint m = cast(lapackint) a.length!1; lapackint n = cast(lapackint) a.length!0; @@ -471,10 +481,39 @@ size_t gesvd(T)( return info; } +/// ditto +size_t gesvd(T)( + char jobu, + char jobvt, + Slice!(T*, 2, Canonical) a, + Slice!(realType!T*) s, + Slice!(T*, 2, Canonical) u, + Slice!(T*, 2, Canonical) vt, + Slice!(T*) work, + Slice!(realType!T*) rwork, + ) + if(isComplex!T) +{ + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldu = cast(lapackint) u._stride.max(1); + lapackint ldvt = cast(lapackint) vt._stride.max(1); + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + + lapack.gesvd_(jobu, jobvt, m, n, a.iterator, lda, s.iterator, u.iterator, ldu, vt.iterator, ldvt, work.iterator, lwork, rwork.iterator, info); + + assert(info >= 0); + return info; +} + unittest { alias s = gesvd!float; alias d = gesvd!double; + alias c = gesvd!cfloat; + alias z = gesvd!cdouble; } /// From 38a0ed119e01512bcf6ea374951e5828bccc969b Mon Sep 17 00:00:00 2001 From: Kazuki Komatsu Date: Mon, 11 Feb 2019 12:20:10 +0900 Subject: [PATCH 4/4] Add unittests for complex-type API --- source/mir/lapack.d | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/source/mir/lapack.d b/source/mir/lapack.d index eba96cc..2d63095 100644 --- a/source/mir/lapack.d +++ b/source/mir/lapack.d @@ -46,6 +46,8 @@ unittest { alias s = getri_wq!float; alias d = getri_wq!double; + alias c = getri_wq!cfloat; + alias z = getri_wq!cdouble; } /// @@ -74,6 +76,8 @@ unittest { alias s = getri!float; alias d = getri!double; + alias c = getri!cfloat; + alias z = getri!cdouble; } /// @@ -99,6 +103,8 @@ unittest { alias s = getrf!float; alias d = getrf!double; + alias c = getrf!cfloat; + alias z = getrf!cdouble; } /// @@ -600,6 +606,8 @@ unittest { alias s = sytrf!float; alias d = sytrf!double; + alias c = sytrf!cfloat; + alias z = sytrf!cdouble; } /// @@ -627,6 +635,8 @@ unittest { alias s = geqrf!float; alias d = geqrf!double; + alias c = geqrf!cfloat; + alias z = geqrf!cdouble; } /// @@ -658,6 +668,8 @@ unittest { alias s = getrs!float; alias d = getrs!double; + alias c = getrs!cfloat; + alias z = getrs!cdouble; } /// @@ -687,6 +699,8 @@ unittest { alias s = potrs!float; alias d = potrs!double; + alias c = potrs!cfloat; + alias z = potrs!cdouble; } /// @@ -718,6 +732,8 @@ unittest { alias s = sytrs2!float; alias d = sytrs2!double; + alias c = sytrs2!cfloat; + alias z = sytrs2!cdouble; } /// @@ -748,6 +764,8 @@ version(none) unittest { alias s = geqrs!float; alias d = geqrs!double; + alias c = geqrs!cfloat; + alias z = geqrs!cdouble; } /// @@ -777,6 +795,8 @@ unittest { alias s = sysv_rook_wk!float; alias d = sysv_rook_wk!double; + alias c = sysv_rook_wk!cfloat; + alias z = sysv_rook_wk!cdouble; } /// @@ -809,6 +829,8 @@ unittest { alias s = sysv_rook!float; alias d = sysv_rook!double; + alias c = sysv_rook!cfloat; + alias z = sysv_rook!cdouble; } /// @@ -950,6 +972,8 @@ unittest { alias s = potrf!float; alias d = potrf!double; + alias c = potrf!cfloat; + alias z = potrf!cdouble; } /// @@ -1024,6 +1048,8 @@ unittest { alias s = sptri!float; alias d = sptri!double; + alias c = sptri!cfloat; + alias z = sptri!cdouble; } /// @@ -1048,6 +1074,8 @@ unittest { alias s = potri!float; alias d = potri!double; + alias c = potri!cfloat; + alias z = potri!cdouble; } /// @@ -1088,6 +1116,8 @@ unittest { alias s = pptri!float; alias d = pptri!double; + alias c = pptri!cfloat; + alias z = pptri!cdouble; } /// @@ -1113,6 +1143,8 @@ unittest { alias s = trtri!float; alias d = trtri!double; + alias c = trtri!cfloat; + alias z = trtri!cdouble; } /// @@ -1156,6 +1188,8 @@ unittest { alias s = tptri!float; alias d = tptri!double; + alias c = tptri!cfloat; + alias z = tptri!cdouble; } ///