Skip to content

Commit

Permalink
blas/gonum: add Zher2k with test
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Jan 10, 2019
1 parent da5b038 commit 0c326c0
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 3 deletions.
3 changes: 0 additions & 3 deletions blas/gonum/cmplx.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,3 @@ func (Implementation) Ztrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas
// Zhemm is not implemented for complex128 matrices in this file.
// Calling it always panics with noComplex, whatever the arguments are.
func (Implementation) Zhemm(s blas.Side, ul blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	panic(noComplex)
}
// Zher2k is a placeholder with the signature of the hermitian rank-2k update.
// It performs no computation: calling it always panics with noComplex.
func (Implementation) Zher2k(ul blas.Uplo, t blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) {
	panic(noComplex)
}
194 changes: 194 additions & 0 deletions blas/gonum/level3cmplx128.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,200 @@ func (Implementation) Zherk(uplo blas.Uplo, trans blas.Transpose, n, k int, alph
}
}

// Zher2k performs one of the hermitian rank-2k operations
//  C = alpha*A*B^H + conj(alpha)*B*A^H + beta*C  if trans == blas.NoTrans
//  C = alpha*A^H*B + conj(alpha)*B^H*A + beta*C  if trans == blas.ConjTrans
// where alpha and beta are scalars with beta real, C is an n×n hermitian matrix
// and A and B are n×k matrices in the first case and k×n matrices in the second case.
//
// The matrices are addressed in row-major order: element (i,j) of C is
// c[i*ldc+j], and likewise for A with lda and B with ldb. Only the triangle of
// C selected by uplo (including the diagonal) is read and written; the
// opposite strict triangle is never accessed.
//
// The imaginary parts of the diagonal elements of C are assumed to be zero, and
// on return they will be set to zero. The exception is an early return that
// leaves C completely untouched, which happens when n == 0, or when beta == 1
// and either alpha == 0 or k == 0.
//
// Zher2k panics if trans is not blas.NoTrans or blas.ConjTrans, if uplo is not
// blas.Lower or blas.Upper, if n or k is negative, or if lda, ldb or ldc is
// too small. When n != 0 it also panics if a, b or c is too short to hold the
// matrix described by its dimensions and leading dimension.
func (Implementation) Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) {
	// row and col are the dimensions of A and B as they are stored:
	// n×k when they are used untransposed, k×n when conjugate-transposed.
	var row, col int
	switch trans {
	default:
		panic(badTranspose)
	case blas.NoTrans:
		row, col = n, k
	case blas.ConjTrans:
		row, col = k, n
	}
	// Validate the scalar arguments. The cases are tested in order, so the
	// first violated condition determines the panic message.
	switch {
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	case lda < max(1, col):
		panic(badLdA)
	case ldb < max(1, col):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	// A matrix with r rows, c columns and leading dimension ld needs (r-1)*ld+c elements.
	if len(a) < (row-1)*lda+col {
		panic(shortA)
	}
	if len(b) < (row-1)*ldb+col {
		panic(shortB)
	}
	if len(c) < (n-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	// The update term vanishes and beta*C == C, so C is left exactly as it
	// was passed in; in particular its diagonal is not modified here.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha == 0 the operation reduces to C = beta*C on the uplo
	// triangle. A and B are not read in this branch.
	if alpha == 0 {
		if uplo == blas.Upper {
			if beta == 0 {
				// Zero the diagonal and everything to its right in each row.
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					// Scale the diagonal element and drop its imaginary part.
					ci[0] = complex(beta*real(ci[0]), 0)
					// The last row has no elements to the right of the diagonal.
					// NOTE(review): c128.DscalUnitary is assumed to multiply every
					// element of its slice argument by the real scalar beta.
					if i != n-1 {
						c128.DscalUnitary(beta, ci[1:])
					}
				}
			}
		} else {
			if beta == 0 {
				// Zero everything up to and including the diagonal in each row.
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					// The first row has no elements to the left of the diagonal.
					if i != 0 {
						c128.DscalUnitary(beta, ci[:i])
					}
					// Scale the diagonal element and drop its imaginary part.
					ci[i] = complex(beta*real(ci[i]), 0)
				}
			}
		}
		return
	}

	conjalpha := cmplx.Conj(alpha)
	cbeta := complex(beta, 0)
	if trans == blas.NoTrans {
		// Form C = alpha*A*B^H + conj(alpha)*B*A^H + beta*C.
		// Each element of the triangle is built from dot products of rows of
		// A and B.
		// NOTE(review): c128.DotcUnitary(x, y) is assumed to return
		// sum_l conj(x[l])*y[l], which is what the formula above requires.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				// ci is the part of row i strictly to the right of the
				// diagonal; the diagonal element is handled separately as cii.
				ci := c[i*ldc+i+1 : i*ldc+n]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					// The old contents of C are overwritten without being read.
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
					// Keep only the real part so the diagonal stays real.
					c[i*ldc+i] = complex(real(cii), 0)
					for jc := range ci {
						// j is the column index in C corresponding to ci[jc].
						j := i + 1 + jc
						ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
					}
				} else {
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
					// Keep only the real part so the diagonal stays real.
					c[i*ldc+i] = complex(real(cii), 0)
					for jc, cij := range ci {
						// j is the column index in C corresponding to ci[jc].
						j := i + 1 + jc
						ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				// ci is the part of row i strictly to the left of the
				// diagonal; the diagonal element is handled separately as cii.
				ci := c[i*ldc : i*ldc+i]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					// The old contents of C are overwritten without being read.
					for j := range ci {
						ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
					}
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
					// Keep only the real part so the diagonal stays real.
					c[i*ldc+i] = complex(real(cii), 0)
				} else {
					for j, cij := range ci {
						ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
					}
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
					// Keep only the real part so the diagonal stays real.
					c[i*ldc+i] = complex(real(cii), 0)
				}
			}
		}
	} else {
		// Form C = alpha*A^H*B + conj(alpha)*B^H*A + beta*C.
		// Each row of the triangle is first scaled by beta and then updated
		// with one contribution per row j of A and B.
		// NOTE(review): c128.AxpyUnitary(s, x, y) is assumed to perform
		// y += s*x element-wise, which is what the formula above requires.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				// ci is row i of C from the diagonal to the last column.
				ci := c[i*ldc+i : i*ldc+n]
				if beta == 0 {
					for jc := range ci {
						ci[jc] = 0
					}
				} else if beta != 1 {
					c128.DscalUnitary(beta, ci)
					// Drop the imaginary part of the diagonal element.
					ci[0] = complex(real(ci[0]), 0)
				} else {
					// No scaling needed; only drop the imaginary part of the diagonal element.
					ci[0] = complex(real(ci[0]), 0)
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					// A zero coefficient contributes nothing, so the update is skipped.
					if aji != 0 {
						c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb+i:j*ldb+n], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda+i:j*lda+n], ci)
					}
				}
				// Drop any imaginary part the updates left on the diagonal element.
				ci[0] = complex(real(ci[0]), 0)
			}
		} else {
			for i := 0; i < n; i++ {
				// ci is row i of C from the first column to the diagonal.
				ci := c[i*ldc : i*ldc+i+1]
				if beta == 0 {
					for j := range ci {
						ci[j] = 0
					}
				} else if beta != 1 {
					c128.DscalUnitary(beta, ci)
					// Drop the imaginary part of the diagonal element.
					ci[i] = complex(real(ci[i]), 0)
				} else {
					// No scaling needed; only drop the imaginary part of the diagonal element.
					ci[i] = complex(real(ci[i]), 0)
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					// A zero coefficient contributes nothing, so the update is skipped.
					if aji != 0 {
						c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb:j*ldb+i+1], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda:j*lda+i+1], ci)
					}
				}
				// Drop any imaginary part the updates left on the diagonal element.
				ci[i] = complex(real(ci[i]), 0)
			}
		}
	}
}

// Zsyrk performs one of the symmetric rank-k operations
// C = alpha*A*A^T + beta*C if trans == blas.NoTrans
// C = alpha*A^T*A + beta*C if trans == blas.Trans
Expand Down
1 change: 1 addition & 0 deletions blas/gonum/level3cmplx128_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ import (

// TestZgemm runs the shared testblas Zgemm suite against impl.
func TestZgemm(t *testing.T) {
	testblas.ZgemmTest(t, impl)
}

// TestZherk runs the shared testblas Zherk suite against impl.
func TestZherk(t *testing.T) {
	testblas.ZherkTest(t, impl)
}

// TestZher2k runs the shared testblas Zher2k suite against impl.
func TestZher2k(t *testing.T) {
	testblas.Zher2kTest(t, impl)
}

// TestZsyrk runs the shared testblas Zsyrk suite against impl.
func TestZsyrk(t *testing.T) {
	testblas.ZsyrkTest(t, impl)
}

// TestZsyr2k runs the shared testblas Zsyr2k suite against impl.
func TestZsyr2k(t *testing.T) {
	testblas.Zsyr2kTest(t, impl)
}
177 changes: 177 additions & 0 deletions blas/testblas/zher2k.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Copyright ©2019 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package testblas

import (
"fmt"
"math/cmplx"
"testing"

"golang.org/x/exp/rand"
"gonum.org/v1/gonum/blas"
)

// Zher2ker is the interface an implementation must satisfy to be exercised by
// Zher2kTest. Its single method has the signature of the hermitian rank-2k
// update, with a complex alpha and a real beta.
type Zher2ker interface {
	Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int)
}

// Zher2kTest exercises impl.Zher2k for every combination of uplo (Upper,
// Lower) and trans (NoTrans, ConjTrans). Each combination runs as a subtest
// named "<uplo>-<trans>" and sweeps a fixed set of matrix sizes n and k.
func Zher2kTest(t *testing.T, impl Zher2ker) {
	var (
		uplos  = []blas.Uplo{blas.Upper, blas.Lower}
		transs = []blas.Transpose{blas.NoTrans, blas.ConjTrans}
		ns     = []int{0, 1, 2, 3, 4, 5}
		ks     = []int{0, 1, 2, 3, 4, 5, 7}
	)
	for ui := 0; ui < len(uplos); ui++ {
		uplo := uplos[ui]
		for ti := 0; ti < len(transs); ti++ {
			trans := transs[ti]
			// The subtest body runs synchronously inside t.Run.
			sweep := func(t *testing.T) {
				for ni := 0; ni < len(ns); ni++ {
					for ki := 0; ki < len(ks); ki++ {
						zher2kTest(t, impl, uplo, trans, ns[ni], ks[ki])
					}
				}
			}
			t.Run(uploString(uplo)+"-"+transString(trans), sweep)
		}
	}
}

// zher2kTest checks impl.Zher2k for one combination of uplo, trans, n and k.
//
// For several leading dimensions and values of alpha and beta it fills A, B
// and C with random data, computes the expected result with zmm on a fully
// expanded hermitian copy of C, and then verifies that Zher2k
//  - does not modify A or B,
//  - does not modify the triangle of C opposite to uplo,
//  - leaves a purely real diagonal in C, and
//  - agrees with the expected result to within tol.
func zher2kTest(t *testing.T, impl Zher2ker, uplo blas.Uplo, trans blas.Transpose, n, k int) {
	const tol = 1e-13

	rnd := rand.New(rand.NewSource(1))

	// Stored dimensions of A and B: n×k untransposed, k×n conjugate-transposed.
	rows, cols := n, k
	if trans == blas.ConjTrans {
		rows, cols = k, n
	}

	// Operations applied to the two zmm operands when forming the reference.
	var opLeft, opRight blas.Transpose = blas.ConjTrans, blas.NoTrans
	if trans == blas.NoTrans {
		opLeft, opRight = blas.NoTrans, blas.ConjTrans
	}

	// randomSlice returns a new slice of the given length filled with
	// random complex numbers drawn from rnd.
	randomSlice := func(length int) []complex128 {
		s := make([]complex128, length)
		for i := range s {
			s[i] = rndComplex128(rnd)
		}
		return s
	}
	// clone returns an independent copy of src.
	clone := func(src []complex128) []complex128 {
		dst := make([]complex128, len(src))
		copy(dst, src)
		return dst
	}
	// dropDiagImag zeroes the imaginary parts on the diagonal of the n×n
	// matrix m with leading dimension ld.
	dropDiagImag := func(m []complex128, ld int) {
		for i := 0; i < n; i++ {
			m[i*ld+i] = complex(real(m[i*ld+i]), 0)
		}
	}
	// diagIsReal reports whether every diagonal element of m is real.
	diagIsReal := func(m []complex128, ld int) bool {
		for i := 0; i < n; i++ {
			if imag(m[i*ld+i]) != 0 {
				return false
			}
		}
		return true
	}
	// mirror fills the strict triangle of m opposite to uplo with the
	// conjugate transpose of the strict triangle selected by uplo.
	mirror := func(m []complex128, ld int) {
		for i := 0; i < n; i++ {
			first, end := 0, i
			if uplo == blas.Upper {
				first, end = i+1, n
			}
			for j := first; j < end; j++ {
				m[j*ld+i] = cmplx.Conj(m[i*ld+j])
			}
		}
	}

	ldas := []int{max(1, cols), cols + 2}
	ldbs := []int{max(1, cols), cols + 3}
	ldcs := []int{max(1, n), n + 4}
	alphas := []complex128{0, 1, complex(0.7, -0.9)}
	betas := []float64{0, 1, 1.3}

	for _, lda := range ldas {
		for _, ldb := range ldbs {
			for _, ldc := range ldcs {
				for _, alpha := range alphas {
					for _, beta := range betas {
						// Random A and B, each with a pristine copy used to
						// detect unexpected modification. The order of the
						// random draws (A, then B, then C) is significant.
						a := randomSlice(rows * lda)
						aOrig := clone(a)
						b := randomSlice(rows * ldb)
						bOrig := clone(b)

						// Random C. Zher2k returns early without touching C
						// in the case below, so the diagonal is made real
						// up front for the later real-diagonal check.
						c := randomSlice(n * ldc)
						if beta == 1 && (k == 0 || alpha == 0) {
							dropDiagImag(c, ldc)
						}
						// Pristine copy of C for checking the triangle
						// opposite to uplo.
						cOrig := clone(c)
						// Full hermitian expansion of C used as the input of
						// the reference computation.
						cFull := clone(c)
						dropDiagImag(cFull, ldc)
						mirror(cFull, ldc)

						// Reference result from two successive zmm calls:
						// first alpha*op(A)*op(B) + beta*C, then the
						// conj(alpha) term with A and B swapped added on top.
						partial := zmm(opLeft, opRight, n, n, k, alpha, a, lda, b, ldb, complex(beta, 0), cFull, ldc)
						want := zmm(opLeft, opRight, n, n, k, cmplx.Conj(alpha), b, ldb, a, lda, 1, partial, ldc)

						// Result under test.
						impl.Zher2k(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc)

						prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldb=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldb, ldc, alpha, beta)

						// The first failing check is reported and the
						// remaining checks for this case are skipped.
						switch {
						case !zsame(a, aOrig):
							t.Errorf("%v: unexpected modification of A", prefix)
							continue
						case !zsame(b, bOrig):
							t.Errorf("%v: unexpected modification of B", prefix)
							continue
						case uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cOrig, ldc):
							t.Errorf("%v: unexpected modification in lower triangle of C", prefix)
							continue
						case uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cOrig, ldc):
							t.Errorf("%v: unexpected modification in upper triangle of C", prefix)
							continue
						case !diagIsReal(c, ldc):
							t.Errorf("%v: diagonal of C has imaginary elements\ngot=%v", prefix, c)
							continue
						}

						// Expand the computed triangle into a full hermitian
						// matrix so it can be compared with the reference.
						mirror(c, ldc)
						if !zEqualApprox(c, want, tol) {
							t.Errorf("%v: unexpected result\nwant=%v\ngot= %v", prefix, want, c)
						}
					}
				}
			}
		}
	}
}

0 comments on commit 0c326c0

Please sign in to comment.