diff --git a/dub.sdl b/dub.sdl index f6acbe8..e38c047 100644 --- a/dub.sdl +++ b/dub.sdl @@ -4,7 +4,7 @@ authors "Ilya Yaroshenko" copyright "Copyright © 2017-2018, Symmetry Investments & Kaleidic Associates" license "BSL-1.0" -dependency "lapack" version="~>0.0.6" +dependency "lapack" version="~>0.1.0" dependency "mir-blas" version=">=1.1.3 <2.0.0" configuration "library" { diff --git a/source/mir/lapack.d b/source/mir/lapack.d index 7fd769d..147fbcd 100644 --- a/source/mir/lapack.d +++ b/source/mir/lapack.d @@ -88,7 +88,7 @@ size_t getrf(T)( ) in { - assert(ipiv.length == min(a.length!0, a.length!1), "getrf: The length of 'ipiv' must equal the smaller of 'a''s dimensions"); + assert(ipiv.length >= min(a.length!0, a.length!1), "getrf: The length of 'ipiv' must be at least the smaller of 'a''s dimensions"); } do { @@ -1411,16 +1411,9 @@ size_t ungqr(T)( ) in { - assert(a.length!0 >= 0, "ungqr: The number of columns of 'a' must be " ~ - "greater than or equal to zero."); //n>=0 - assert(a.length!1 >= a.length!0, "ungqr: The number of columns of 'a' " ~ - "must be greater than or equal to the number of its rows."); //m>=n - assert(tau.length >= 0, "ungqr: The input 'tau' must have length greater " ~ - "than or equal to zero."); //k>=0 - assert(a.length!0 >= tau.length, "ungqr: The number of columns of 'a' " ~ - "must be greater than or equal to the length of 'tau'."); //n>=k - assert(work.length >= a.length!0, "ungqr: The length of 'work' must be " ~ - "greater than or equal to the number of rows of 'a'."); //lwork>=n + assert(a.length!1 >= a.length!0, "ungqr: The number of columns of 'a' must be greater than or equal to the number of its rows."); //m>=n + assert(a.length!0 >= tau.length, "ungqr: The number of columns of 'a' must be greater than or equal to the length of 'tau'."); //n>=k + assert(work.length >= a.length!0, "ungqr: The length of 'work' must be greater than or equal to the number of rows of 'a'."); //lwork>=n } do { @@ -1444,3 +1437,497 @@ unittest alias s = ungqr!cfloat; alias d = ungqr!cdouble; } + +alias orghr = unghr; // this is the name for the real type vairant of ungqr + +/// +size_t unghr(T)( + Slice!(T*, 2, Canonical) a, + Slice!(T*) tau, + Slice!(T*) work, +) +in +{ + assert(a.length!1 >= a.length!0); //m>=n + assert(a.length!0 >= tau.length); //n>=k + assert(work.length >= a.length!0); //lwork>=n +} +do +{ + lapackint m = cast(lapackint) a.length!1; + lapackint n = cast(lapackint) a.length!0; + lapackint k = cast(lapackint) tau.length; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + static if (isComplex!T){ + lapack.ungqr_(m, n, k, a.iterator, lda, tau.iterator, work.iterator, lwork, info); + } + else { + lapack.orgqr_(m, n, k, a.iterator, lda, tau.iterator, work.iterator, lwork, info); + } + + ///if info == 0: successful exit. + ///if info < 0: if info == -i, the i-th argument had an illegal value. + assert(info >= 0); + return cast(size_t)info; +} + +unittest +{ + alias orghrf = orghr!float; + alias orghrd = orghr!double; + alias unghrf = unghr!float; + alias unghrd = unghr!double; + alias unghrcf = unghr!cfloat; + alias unghrcd = unghr!cdouble; +} + +/// +size_t gehrd(T)( + Slice!(T*, 2, Canonical) a, + Slice!(T*) tau, + Slice!(T*) work, + lapackint* ilo, + lapackint* ihi +) +in +{ + assert(a.length!1 >= a.length!0, "gehrd: The number of columns of 'a' must be greater than or equal to the number of its rows."); //m>=n + assert(a.length!0 >= tau.length, "gehrd: The number of columns of 'a' must be greater than or equal to the length of 'tau'."); //n>=k + assert(work.length >= a.length!0, "gehrd: The length of 'work' must be greater than or equal to the number of rows of 'a'."); //lwork>=n +} +do +{ + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + lapack.gehrd_(n, ilo, ihi, a.iterator, lda, tau.iterator, work.iterator, lwork, info); + ///if info == 0: successful exit. + ///if info < 0: if info == -i, the i-th argument had an illegal value. + assert(info >= 0); + return cast(size_t)info; +} + +unittest +{ + alias s = gehrd!cfloat; + alias d = gehrd!cdouble; +} + +size_t hsein(T)( + char side, + char eigsrc, + char initv, + lapackint* select, //actually a logical bitset stored in here + Slice!(T*, 2, Canonical) h, + Slice!(T*) wr, + Slice!(T*) wi, + Slice!(T*, 2, Canonical) vl, + Slice!(T*, 2, Canonical) vr, + lapackint* m, + Slice!(T*) work, + lapackint* ifaill, + lapackint* ifailr, + lapackint* ilo, + lapackint* ihi, +) + if (!isComplex!T) +in +{ + assert(h.length!1 >= h.length!0, "hsein: The number of columns of 'h' " ~ + "must be greater than or equal to the number of its rows."); //m>=n + assert(wr.length >= 1, "hsein: The input 'wr' must have length greater " ~ + "than or equal to one."); + assert(wr.length >= h.length!0, "hsein: The input 'wr' must have length greater " ~ + "than or equal to the number of rows of 'h'."); + assert(wr.length >= 1.0, "hsein: The input 'wr' must have length greater " ~ + "than or equal to 1."); + assert(wi.length >= 1, "hsein: The input 'wi' must have length greater " ~ + "than or equal to one."); + assert(wi.length >= h.length!0, "hsein: The input 'wi' must have length greater " ~ + "than or equal to the number of rows of 'h'."); + assert(wi.length >= 1.0, "hsein: The input 'wi' must have length greater " ~ + "than or equal to 1."); + assert(work.length >= h.length!0 * (h.length!0 + 2), "hsein: The length of 'work' must be " ~ + "greater than or equal to the square of the number of rows of 'h' plus two additional rows for real types."); + assert(side == 'R' || side == 'L' || side == 'B', "hsein: The char, 'side' must be " ~ + "one of 'R', 'L' or 'B'."); + assert(eigsrc == 'Q' || eigsrc == 'N', "hsein: The char, 'eigsrc', must be " ~ + "one of 'Q' or 'R'."); + assert(initv == 'N' || initv == 'U', "hsein: The char, 'initv', must be " ~ + "one of 'N' or 'U'."); + assert(side != 'L' || side != 'B' || vl.length!1 >= 1, "hsein: Slice 'vl' must be" ~ + "at least the size of '1' when 'side' is set to 'L' or 'B'."); + assert(side != 'R' || vl.length!1 >= 1, "hsein: Slice 'vl' must be" ~ + "length greater than 1 when 'side' is 'R'."); + assert(side != 'R' || side != 'B' || vr.length!1 >= 1, "hsein: Slice 'vr' must be" ~ + "at least the size of '1' when 'side' is set to 'R' or 'B'."); + assert(side != 'L' || vl.length!1 >= 1, "hsein: Slice 'vr' must be" ~ + "length greater than 1 when 'side' is 'L'."); +} +do +{ + lapackint info; + lapackint mm = cast(lapackint) vl.length!1; + lapackint n = cast(lapackint) h.length!0; + lapackint ldh = cast(lapackint) h._stride.max(1); + lapackint ldvl = cast(lapackint) vl._stride.max(1); + lapackint ldvr = cast(lapackint) vr._stride.max(1); + //need to seperate these methods then probably provide a wrap which does this as that's the easiest way without bloating the base methods + lapack.hsein_(side, eigsrc, initv, select, n, h.iterator, ldh, wr.iterator, wi.iterator, vl.iterator, ldvl, vr.iterator, ldvr, mm, *m, work.iterator, ifaill, ifailr, info); + assert(info >= 0); + ///if any of ifaill or ifailr entries are non-zero then that has failed to converge. + ///ifail?[i] = j > 0 if the eigenvector stored in the i-th column of v?, coresponding to the jth eigenvalue, fails to converge. + assert(*ifaill == 0); + assert(*ifailr == 0); + return info; +} + +size_t hsein(T, realT)( + char side, + char eigsrc, + char initv, + lapackint* select, //actually a logical bitset stored in here + Slice!(T*, 2, Canonical) h, + Slice!(T*) w, + Slice!(T*, 2, Canonical) vl, + Slice!(T*, 2, Canonical) vr, + lapackint* m, + Slice!(T*) work, + Slice!(realT*) rwork, + lapackint* ifaill, + lapackint* ifailr, + lapackint* ilo, + lapackint* ihi, +) + if (isComplex!T && is(realType!T == realT)) +in +{ + assert(h.length!1 >= h.length!0, "hsein: The number of columns of 'h' " ~ + "must be greater than or equal to the number of its rows."); //m>=n + assert(w.length >= 1, "hsein: The input 'w' must have length greater " ~ + "than or equal to one."); + assert(w.length >= h.length!0, "hsein: The input 'w' must have length greater " ~ + "than or equal to the number of rows of 'h'."); + assert(w.length >= 1.0, "hsein: The input 'w' must have length greater " ~ + "than or equal to 1."); + assert(work.length >= h.length!0 * h.length!0, "hsein: The length of 'work' must be " ~ + "greater than or equal to the square of the number of rows of 'h' for complex types."); + assert(side == 'R' || side == 'L' || side == 'B', "hsein: The char, 'side' must be " ~ + "one of 'R', 'L' or 'B'."); + assert(eigsrc == 'Q' || eigsrc == 'N', "hsein: The char, 'eigsrc', must be " ~ + "one of 'Q' or 'R'."); + assert(initv == 'N' || initv == 'U', "hsein: The char, 'initv', must be " ~ + "one of 'N' or 'U'."); + assert(side != 'L' || side != 'B' || vl.length!1 >= 1, "hsein: Slice 'vl' must be" ~ + "at least the size of '1' when 'side' is set to 'L' or 'B'."); + assert(side != 'R' || vl.length!1 >= 1, "hsein: Slice 'vl' must be" ~ + "length greater than 1 when 'side' is 'R'."); + assert(side != 'R' || side != 'B' || vr.length!1 >= 1, "hsein: Slice 'vr' must be" ~ + "at least the size of '1' when 'side' is set to 'R' or 'B'."); + assert(side != 'L' || vl.length!1 >= 1, "hsein: Slice 'vr' must be" ~ + "length greater than 1 when 'side' is 'L'."); +} +do { + lapackint n = cast(lapackint) h.length!0; + lapackint ldh = cast(lapackint) h._stride.max(1); + lapackint ldvl = cast(lapackint) vl._stride.max(1); + lapackint ldvr = cast(lapackint) vr._stride.max(1); + lapackint mm = cast(lapackint) vl.length!1; + lapackint info = void; + //could compute mm and m from vl and/or vr and T + lapack.hsein_(side, eigsrc, initv, select, n, h.iterator, ldh, w.iterator, vl.iterator, ldvl, vr.iterator, ldvr, mm, *m, work.iterator, rwork.iterator, ifaill, ifailr, info); + assert(info >= 0); + ///if any of ifaill or ifailr entries are non-zero then that has failed to converge. + ///ifail?[i] = j > 0 if the eigenvector stored in the i-th column of v?, coresponding to the jth eigenvalue, fails to converge. + assert(*ifaill == 0); + assert(*ifailr == 0); + return info; +} + + +unittest +{ + alias f = hsein!(float); + alias d = hsein!(double); + alias s = hsein!(cfloat,float); + alias c = hsein!(cdouble,double); +} + +alias ormhr = unmhr; + +/// +size_t unmhr(T)( + char side, + char trans, + Slice!(T*, 2, Canonical) a, + Slice!(T*) tau, + Slice!(T*, 2, Canonical) c, + Slice!(T*) work, + lapackint* ilo, + lapackint* ihi +) +in +{ + assert(a.length!0 >= 0, "ormhr: The number of columns of 'a' must be " ~ + "greater than or equal to zero."); //n>=0 + assert(a.length!1 >= a.length!0, "ormhr: The number of columns of 'a' " ~ + "must be greater than or equal to the number of its rows."); //m>=n + assert(c.length!0 >= 0, "ormhr: The number of columns of 'c' must be " ~ + "greater than or equal to zero."); //n>=0 + assert(c.length!1 >= c.length!0, "ormhr: The number of columns of 'c' " ~ + "must be greater than or equal to the number of its rows."); //m>=n + assert(tau.length >= 0, "ormhr: The input 'tau' must have length greater " ~ + "than or equal to zero."); //k>=0 + assert(a.length!0 >= tau.length, "ormhr: The number of columns of 'a' " ~ + "must be greater than or equal to the length of 'tau'."); //n>=k + assert(work.length >= a.length!0, "ormhr: The length of 'work' must be " ~ + "greater than or equal to the number of rows of 'a'."); //lwork>=n + assert(side == 'L' || side == 'R', "ormhr: 'side' must be" ~ + "one of 'L' or 'R'."); + assert(trans == 'N' || trans == 'T', "ormhr: 'trans' must be" ~ + "one of 'N' or 'T'."); +} +do +{ + lapackint m = cast(lapackint) a.length!0; + lapackint n = cast(lapackint) a.length!1; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldc = cast(lapackint) c._stride.max(1); + lapackint lwork = cast(lapackint) work.length; + lapackint info = void; + static if (!isComplex!T){ + lapack.ormhr_(side, trans, m, n, ilo, ihi, a.iterator, lda, tau.iterator, c.iterator, ldc, work.iterator, lwork, info); + } + else { + lapack.unmhr_(side, trans, m, n, ilo, ihi, a.iterator, lda, tau.iterator, c.iterator, ldc, work.iterator, lwork, info); + } + ///if info == 0: successful exit. + ///if info < 0: if info == -i, the i-th argument had an illegal value. + assert(info >= 0); + return cast(size_t)info; +} + +unittest +{ + alias s = unmhr!cfloat; + alias d = unmhr!cdouble; + alias a = ormhr!double; + alias b = ormhr!float; +} + +size_t hseqr(T)( + char job, + char compz, + Slice!(T*, 2, Canonical) h, + Slice!(T*) w, + Slice!(T*, 2, Canonical) z, + Slice!(T*) work, + lapackint* ilo, + lapackint* ihi +) + if (isComplex!T) +in +{ + assert(job == 'E' || job == 'S', "hseqr"); + assert(compz == 'N' || compz == 'I' || compz == 'V', "hseqr"); + assert(h.length!1 >= h.length!0, "hseqr"); + assert(h.length!1 >= 1, "hseqr"); + assert(compz != 'V' || compz != 'I' || (z.length!1 >= h.length!0 && z.length!1 >= 1), "hseqr"); + assert(compz != 'N' || z.length!1 >= 1); + assert(work.length!0 >= 1, "hseqr"); + assert(work.length!0 >= h.length!0, "hseqr"); +} +do +{ + lapackint n = cast(lapackint) h.length!0; + lapackint ldh = cast(lapackint) h._stride.max(1); + lapackint ldz = cast(lapackint) z._stride.max(1); + lapackint lwork = cast(lapackint) work.length!0; + lapackint info; + lapack.hseqr_(job,compz,n,ilo,ihi,h.iterator, ldh, w.iterator, z.iterator, ldz, work.iterator, lwork, info); + assert(info >= 0); + return cast(size_t)info; +} + +size_t hseqr(T)( + char job, + char compz, + Slice!(T*, 2, Canonical) h, + Slice!(T*) wr, + Slice!(T*) wi, + Slice!(T*, 2, Canonical) z, + Slice!(T*) work, + lapackint* ilo, + lapackint* ihi +) + if (!isComplex!T) +in +{ + assert(job == 'E' || job == 'S', "hseqr"); + assert(compz == 'N' || compz == 'I' || compz == 'V', "hseqr"); + assert(h.length!1 >= h.length!0, "hseqr"); + assert(h.length!1 >= 1, "hseqr"); + assert(compz != 'V' || compz != 'I' || (z.length!1 >= h.length!0 && z.length!1 >= 1), "hseqr"); + assert(compz != 'N' || z.length!1 >= 1); + assert(work.length!0 >= 1, "hseqr"); + assert(work.length!0 >= h.length!0, "hseqr"); +} +do +{ + lapackint n = cast(lapackint) h.length!0; + lapackint ldh = cast(lapackint) h._stride.max(1); + lapackint ldz = cast(lapackint) z._stride.max(1); + lapackint lwork = cast(lapackint) work.length!0; + lapackint info; + lapack.hseqr_(job,compz,n,ilo,ihi,h.iterator, ldh, wr.iterator, wi.iterator, z.iterator, ldz, work.iterator, lwork, info); + assert(info >= 0); + return cast(size_t)info; +} + +unittest +{ + alias f = hseqr!float; + alias d = hseqr!double; + alias s = hseqr!cfloat; + alias c = hseqr!cdouble; +} + +size_t trevc(T)(char side, + char howmany, + lapackint select, + Slice!(T*, 2, Canonical) t, + Slice!(T*, 2, Canonical) vl, + Slice!(T*, 2, Canonical) vr, + lapackint* m, + Slice!(T*) work +) +do +{ + lapackint n = cast(lapackint)t.length!0; + lapackint ldt = cast(lapackint) t._stride.max(1); + lapackint ldvl = cast(lapackint) vl._stride.max(1); + lapackint ldvr = cast(lapackint) vr._stride.max(1); + lapackint mm = cast(lapackint) vr.length!1; + //select should be lapack_logical + lapackint info; + static if(!isComplex!T){ + lapack.trevc_(side, howmany, &select, n, t.iterator, ldt, vl.iterator, ldvl, vr.iterator, ldvr, mm, *m, work.iterator, info); + } + else { + lapack.trevc_(side, howmany, &select, n, t.iterator, ldt, vl.iterator, ldvl, vr.iterator, ldvr, mm, *m, work.iterator, null, info); + } + assert(info >= 0); + return cast(size_t)info; +} + +unittest +{ + alias f = trevc!float; + alias d = trevc!double; + alias s = trevc!cfloat; + alias c = trevc!cdouble; +} + +alias complexType(T : double) = cdouble; +alias complexType(T : float) = cfloat; +alias complexType(T : real) = creal; +alias complexType(T : isComplex!T) = T; + +size_t gebal(T, realT)(char job, + Slice!(T*, 2, Canonical) a, + lapackint* ilo, + lapackint* ihi, + Slice!(realT*) scale +) + if (!isComplex!T || (isComplex!T && is(realType!T == realT))) +{ + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint info = void; + lapack.gebal_(job, n, a.iterator, lda, ilo, ihi, scale.iterator, info); + assert(info >= 0); + return cast(size_t)info; +} + +unittest +{ + alias a = gebal!(double,double); + alias b = gebal!(cdouble,double); + alias c = gebal!(float,float); + alias d = gebal!(cfloat,float); +} + +size_t gebak(T, realT)( + char job, + char side, + lapackint* ilo, + lapackint* ihi, + Slice!(realT*) scale, + Slice!(T*, 2, Canonical) v +) + if (!isComplex!T || (isComplex!T && is(realType!T == realT))) +{ + lapackint n = cast(lapackint) scale.length!0; + lapackint m = cast(lapackint) v.length!1;//num evects + lapackint ldv = cast(lapackint) v._stride.max(1); + lapackint info = void; + lapack.gebak_(job, side, n, ilo, ihi, scale.iterator, m, v.iterator, ldv, info); + assert(info >= 0); + return cast(size_t)info; +} + +unittest +{ + alias a = gebak!(double,double); + alias b = gebak!(cdouble,double); + alias c = gebak!(float,float); + alias d = gebak!(cfloat,float); +} + +size_t geev(T, realT)( + char jobvl, + char jobvr, + Slice!(T*, 2, Canonical) a, + Slice!(T*) w, + Slice!(T*, 2, Canonical) vl, + Slice!(T*, 2, Canonical) vr, + Slice!(T*) work, + Slice!(realT*) rwork +) + if (isComplex!T && is(realType!T == realT)) +{ + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldvr = cast(lapackint) vr._stride.max(1); + lapackint ldvl = cast(lapackint) vl._stride.max(1); + lapackint info = void; + lapackint lwork = cast(lapackint)work.length!0; + lapack.geev_(jobvl, jobvr, n, a.iterator, lda, w.iterator, vl.iterator, ldvl, vr.iterator, ldvr, work.iterator, lwork, rwork.iterator, info); + assert(info >= 0); + return info; +} +size_t geev(T)( + char jobvl, + char jobvr, + Slice!(T*, 2, Canonical) a, + Slice!(T*) wr, + Slice!(T*) wi, + Slice!(T*, 2, Canonical) vl, + Slice!(T*, 2, Canonical) vr, + Slice!(T*) work +) + if (!isComplex!T) +{ + lapackint n = cast(lapackint) a.length!0; + lapackint lda = cast(lapackint) a._stride.max(1); + lapackint ldvr = cast(lapackint) vr._stride.max(1); + lapackint ldvl = cast(lapackint) vl._stride.max(1); + lapackint info = void; + lapackint lwork = cast(lapackint)work.length!0; + lapack.geev_(jobvl, jobvr, n, a.iterator, lda, wr.iterator, wi.iterator, vl.iterator, ldvl, vr.iterator, ldvr, work.iterator, lwork, info); + assert(info >= 0); + return info; +}