Skip to content

Commit

Permalink
mat: generalise MulVec
Browse files Browse the repository at this point in the history
  • Loading branch information
kortschak committed Dec 21, 2017
1 parent e0ca1bd commit e024721
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 52 deletions.
113 changes: 63 additions & 50 deletions mat/vector.go
Expand Up @@ -503,105 +503,118 @@ func (v *VecDense) DivElemVec(a, b Vector) {

// MulVec computes a * b. The result is stored into the receiver.
// MulVec panics if the number of columns in a does not equal the number of rows in b.
func (v *VecDense) MulVec(a Matrix, b *VecDense) {
func (v *VecDense) MulVec(a Matrix, b Vector) {
r, c := a.Dims()
br := b.Len()
br, _ := b.Dims()
if c != br {
panic(ErrShape)
}

if v != b {
v.checkOverlap(b.mat)
aU, trans := untranspose(a)
var bmat blas64.Vector
fast := true
bU, _ := untranspose(b)
if rv, ok := bU.(RawVectorer); ok {
bmat = rv.RawVector()
if v != b {
v.checkOverlap(bmat)
}
} else {
fast = false
}

a, trans := untranspose(a)
ar, ac := a.Dims()
v.reuseAs(r)
var restore func()
if v == a {
v, restore = v.isolatedWorkspace(a.(*VecDense))
if v == aU {
v, restore = v.isolatedWorkspace(aU.(*VecDense))
defer restore()
} else if v == b {
v, restore = v.isolatedWorkspace(b)
defer restore()
}

switch a := a.(type) {
// TODO(kortschak): Improve the non-fast paths.
switch aU := aU.(type) {
case *VecDense:
if v != a {
v.checkOverlap(a.mat)
if v != aU {
v.checkOverlap(aU.mat)
}

if a.Len() == 1 {
if aU.Len() == 1 {
// {1,1} x {1,n}
av := a.At(0, 0)
av := aU.At(0, 0)
if fast {
for i := 0; i < b.Len(); i++ {
v.mat.Data[i*v.mat.Inc] = av * bmat.Data[i*bmat.Inc]
}
return
}
for i := 0; i < b.Len(); i++ {
v.mat.Data[i*v.mat.Inc] = av * b.mat.Data[i*b.mat.Inc]
v.mat.Data[i*v.mat.Inc] = av * b.AtVec(i)
}
return
}
if b.Len() == 1 {
// {1,n} x {1,1}
bv := b.At(0, 0)
for i := 0; i < a.Len(); i++ {
v.mat.Data[i*v.mat.Inc] = bv * a.mat.Data[i*a.mat.Inc]
bv := b.AtVec(0)
for i := 0; i < aU.Len(); i++ {
v.mat.Data[i*v.mat.Inc] = bv * aU.mat.Data[i*aU.mat.Inc]
}
return
}
// {n,1} x {1,n}
var sum float64
for i := 0; i < c; i++ {
sum += a.At(i, 0) * b.At(i, 0)
sum += aU.AtVec(i) * b.AtVec(i)
}
v.SetVec(0, sum)
return
case RawSymmetricer:
amat := a.RawSymmetric()
blas64.Symv(1, amat, b.mat, 0, v.mat)
if fast {
amat := aU.RawSymmetric()
blas64.Symv(1, amat, bmat, 0, v.mat)
return
}
case RawTriangular:
v.CopyVec(b)
amat := a.RawTriangular()
amat := aU.RawTriangular()
ta := blas.NoTrans
if trans {
ta = blas.Trans
}
blas64.Trmv(ta, amat, v.mat)
case RawMatrixer:
amat := a.RawMatrix()
// We don't know that a is a *Dense, so make
// a temporary Dense to check overlap.
(&Dense{mat: amat}).checkOverlap(v.asGeneral())
t := blas.NoTrans
if trans {
t = blas.Trans
if fast {
amat := aU.RawMatrix()
// We don't know that a is a *Dense, so make
// a temporary Dense to check overlap.
(&Dense{mat: amat}).checkOverlap(v.asGeneral())
t := blas.NoTrans
if trans {
t = blas.Trans
}
blas64.Gemv(t, 1, amat, bmat, 0, v.mat)
return
}
blas64.Gemv(t, 1, amat, b.mat, 0, v.mat)
default:
if trans {
col := make([]float64, ar)
for c := 0; c < ac; c++ {
for i := range col {
col[i] = a.At(i, c)
}
if fast {
for i := 0; i < r; i++ {
var f float64
for i, e := range col {
f += e * b.mat.Data[i*b.mat.Inc]
for j := 0; j < c; j++ {
f += a.At(i, j) * bmat.Data[j*bmat.Inc]
}
v.mat.Data[c*v.mat.Inc] = f
}
} else {
row := make([]float64, ac)
for r := 0; r < ar; r++ {
for i := range row {
row[i] = a.At(r, i)
}
var f float64
for i, e := range row {
f += e * b.mat.Data[i*b.mat.Inc]
}
v.mat.Data[r*v.mat.Inc] = f
v.mat.Data[i*v.mat.Inc] = f
}
return
}
}

for i := 0; i < r; i++ {
var f float64
for j := 0; j < c; j++ {
f += a.At(i, j) * b.AtVec(j)
}
v.mat.Data[i*v.mat.Inc] = f
}
}

Expand Down
4 changes: 2 additions & 2 deletions mat/vector_test.go
Expand Up @@ -184,10 +184,10 @@ func TestVecDenseAtSet(t *testing.T) {
func TestVecDenseMul(t *testing.T) {
method := func(receiver, a, b Matrix) {
type mulVecer interface {
MulVec(a Matrix, b *VecDense)
MulVec(a Matrix, b Vector)
}
rd := receiver.(mulVecer)
rd.MulVec(a, b.(*VecDense))
rd.MulVec(a, b.(Vector))
}
denseComparison := func(receiver, a, b *Dense) {
receiver.Mul(a, b)
Expand Down

0 comments on commit e024721

Please sign in to comment.