diff --git a/lapack/testlapack/dpttrf.go b/lapack/testlapack/dpttrf.go index 7b90a2db5..4cd80184b 100644 --- a/lapack/testlapack/dpttrf.go +++ b/lapack/testlapack/dpttrf.go @@ -10,6 +10,8 @@ import ( "testing" "golang.org/x/exp/rand" + + "gonum.org/v1/gonum/lapack" ) type Dpttrfer interface { @@ -75,19 +77,18 @@ func dpttrfResidual(n int, d, e, dFac, eFac []float64) float64 { } // Compute the 1-norm of the difference L*D*Lᵀ - A. - var anorm, resid float64 + var resid float64 if n == 1 { - anorm = d[0] resid = math.Abs(dDiff[0]) } else { - anorm = math.Max(d[0]+math.Abs(e[0]), d[n-1]+math.Abs(e[n-2])) resid = math.Max(math.Abs(dDiff[0])+math.Abs(eDiff[0]), math.Abs(dDiff[n-1])+math.Abs(eDiff[n-2])) for i := 1; i < n-1; i++ { - anorm = math.Max(anorm, d[i]+math.Abs(e[i])+math.Abs(e[i-1])) resid = math.Max(resid, math.Abs(dDiff[i])+math.Abs(eDiff[i-1])+math.Abs(eDiff[i])) } } + anorm := dlanst(lapack.MaxColumnSum, n, d, e) + // Compute norm(L*D*Lᵀ - A)/(n * norm(A)). if anorm == 0 { if resid != 0 { diff --git a/lapack/testlapack/locallapack.go b/lapack/testlapack/locallapack.go index 197330be5..3934b3411 100644 --- a/lapack/testlapack/locallapack.go +++ b/lapack/testlapack/locallapack.go @@ -677,3 +677,39 @@ func dlantb(norm lapack.MatrixNorm, uplo blas.Uplo, diag blas.Diag, n, k int, a } return value } + +func dlanst(norm lapack.MatrixNorm, n int, d, e []float64) float64 { + if n == 0 { + return 0 + } + var value float64 + switch norm { + case lapack.MaxAbs: + if n == 1 { + value = math.Abs(d[0]) + } else { + for _, di := range d[:n] { + value = math.Max(value, math.Abs(di)) + } + for _, ei := range e[:n-1] { + value = math.Max(value, math.Abs(ei)) + } + } + case lapack.MaxColumnSum, lapack.MaxRowSum: + if n == 1 { + value = math.Abs(d[0]) + } else { + value = math.Abs(d[0]) + math.Abs(e[0]) + value = math.Max(value, math.Abs(d[n-1])+math.Abs(e[n-2])) + for i := 1; i < n-1; i++ { + sum := math.Abs(d[i]) + math.Abs(e[i]) + math.Abs(e[i-1]) + value = math.Max(value, sum) + } + } + case lapack.Frobenius: + panic("not implemented") + default: + panic("invalid norm") + } + return value +}