Skip to content

Commit

Permalink
lapack/gonum: add Dptsv
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Nov 23, 2023
1 parent 44d84c9 commit 3462e90
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 0 deletions.
49 changes: 49 additions & 0 deletions lapack/gonum/dptsv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright ©2023 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

// Dptsv computes the solution to system of linear equations
//
// A * X = B
//
// where A is an n×n symmetric positive definite tridiagonal matrix, and X and B
// are n×nrhs matrices. A is factored as A = L*D*Lᵀ, and the factored form of A
// is then used to solve the system of equations.
//
// On entry, d contains the n diagonal elements of A and e contains the (n-1)
// subdiagonal elements of A. On return, d contains the n diagonal elements of
// the diagonal matrix D from the factorization A = L*D*Lᵀ and e contains the
// (n-1) subdiagonal elements of the unit bidiagonal factor L.
//
// Dptsv returns whether the solution X has been successfully computed.
func (impl Implementation) Dptsv(n, nrhs int, d, e []float64, b []float64, ldb int) (ok bool) {
switch {
case n < 0:
panic(nLT0)
case nrhs < 0:
panic(nrhsLT0)
case ldb < max(1, nrhs):
panic(badLdB)
}

if n == 0 || nrhs == 0 {
return true
}

switch {
case len(d) < n:
panic(shortD)
case len(e) < n-1:
panic(shortE)
case len(b) < (n-1)*ldb+nrhs:
panic(shortB)
}

ok = impl.Dpttrf(n, d, e)
if ok {
impl.Dpttrs(n, nrhs, d, e, b, ldb)
}
return ok
}
5 changes: 5 additions & 0 deletions lapack/gonum/lapack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,11 @@ func TestDpttrs(t *testing.T) {
testlapack.DpttrsTest(t, impl)
}

func TestDptsv(t *testing.T) {
t.Parallel()
testlapack.DptsvTest(t, impl)
}

func TestDrscl(t *testing.T) {
t.Parallel()
testlapack.DrsclTest(t, impl)
Expand Down
55 changes: 55 additions & 0 deletions lapack/testlapack/dptsv.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright ©2023 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package testlapack

import (
"fmt"
"testing"

"golang.org/x/exp/rand"
)

type Dptsver interface {
Dptsv(n, nrhs int, d, e []float64, b []float64, ldb int) (ok bool)
}

func DptsvTest(t *testing.T, impl Dptsver) {
rnd := rand.New(rand.NewSource(1))
for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 20, 50, 51, 52, 53, 54, 100} {
for _, nrhs := range []int{0, 1, 2, 3, 4, 5, 10, 20, 50} {
for _, ldb := range []int{max(1, nrhs), nrhs + 3} {
dptsvTest(t, impl, rnd, n, nrhs, ldb)
}
}
}
}

func dptsvTest(t *testing.T, impl Dptsver, rnd *rand.Rand, n, nrhs, ldb int) {
const tol = 1e-15

name := fmt.Sprintf("n=%v", n)

// Generate a random diagonally dominant symmetric tridiagonal matrix A.
d, e := newRandomSymTridiag(n, rnd)

// Generate a random solution matrix X.
xWant := randomGeneral(n, nrhs, ldb, rnd)

// Compute the right-hand side.
b := zeros(n, nrhs, ldb)
dstmm(n, nrhs, d, e, xWant.Data, xWant.Stride, b.Data, b.Stride)

// Solve A*X=B.
ok := impl.Dptsv(n, nrhs, d, e, b.Data, b.Stride)
if !ok {
t.Errorf("%v: Dptsv failed", name)
return
}

resid := dpttrsResidual(b, xWant)
if resid > tol {
t.Errorf("%v: unexpected solution: |diff| = %v, want <= %v", name, resid, tol)
}
}

0 comments on commit 3462e90

Please sign in to comment.