Skip to content

Commit

Permalink
Added TensorMul method to linalg for f64. Along the way, borrowClone(…
Browse files Browse the repository at this point in the history
…) is implemented
  • Loading branch information
chewxy committed Sep 19, 2016
1 parent 5422c3d commit 87bcc31
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 26 deletions.
50 changes: 26 additions & 24 deletions tensor/f64/arith_linalg_methods.go
@@ -1,8 +1,6 @@
package tensorf64

import (
"log"

"github.com/chewxy/gorgonia/tensor/types"
"github.com/gonum/blas"
"github.com/pkg/errors"
Expand Down Expand Up @@ -355,8 +353,14 @@ func (t *Tensor) outer(other, retVal *Tensor) {
return
}

// TensorMul is for multiplying Tensors with more than 2 dimensions. It exploits a trick in D/SGEMM, and reshaping using the fortran order
// BROKEN. DO NOT USE UNTIL FIXED.
// TensorMul is for multiplying Tensors with more than 2 dimensions.
//
// The algorithm is conceptually simple (but tricky to get right):
// 1. Transpose and reshape the Tensors in such a way that both t and other are 2D matrices
// 2. Use DGEMM to multiply them
// 3. Reshape the results to be the new expected result
//
// This function is a Go implementation of Numpy's tensordot method. It simplifies a lot of what Numpy does.
func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, err error) {
ts := t.Shape()
td := len(ts)
Expand All @@ -382,6 +386,7 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
}
}
}

if !sameLength {
err = shapeMismatchError(ts, os)
return
Expand Down Expand Up @@ -418,11 +423,9 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
retShape1 := types.BorrowInts(len(ts))
defer types.ReturnInts(retShape1)
retShape1 = retShape1[:0]
log.Printf("len(retShape %d", len(retShape1))
for _, ni := range notins {
retShape1 = append(retShape1, ts[ni])
}
log.Printf("NewAxesA: %v | notins: %v | newShapeT: %v | retShape:%v", newAxesA, notins, newShapeT, retShape1)

// work on other now
notins = notins[:0]
Expand All @@ -438,10 +441,11 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
notins = append(notins, i)
}
}

newAxesB := types.BorrowInts(len(notins) + len(axesB))
defer types.ReturnInts(newAxesB)
newAxesB = newAxesB[:0]
newAxesB = append(notins, axesB...)
newAxesB = append(axesB, notins...)

newShapeO := types.Shape(types.BorrowInts(2))
defer types.ReturnInts(newShapeO)
Expand All @@ -453,32 +457,39 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
for _, ni := range notins {
retShape2 = append(retShape2, os[ni])
}
log.Printf("NewAxesB: %v | notins: %v | newShapeO: %v | retShape:%v", newAxesB, notins, newShapeO, retShape2)

if err = t.T(newAxesA...); err != nil {
// we borrowClone because we don't want to touch the original Tensors
doT := t.borrowClone()
doOther := other.borrowClone()
defer ReturnTensor(doT)
defer ReturnTensor(doOther)

if err = doT.T(newAxesA...); err != nil {
return
}
log.Printf("T.strides: %v", t.Strides())
log.Printf("%v", t)
doT.Transpose() // we have to materialize the transpose first or the underlying data won't be changed and the reshape that follows would be meaningless

if err = t.Reshape(newShapeT...); err != nil {
if err = doT.Reshape(newShapeT...); err != nil {
return
}

if err = other.T(newAxesB...); err != nil {
if err = doOther.T(newAxesB...); err != nil {
return
}
doOther.Transpose()

if err = other.Reshape(newShapeO...); err != nil {
if err = doOther.Reshape(newShapeO...); err != nil {
return
}

if retVal, err = t.MatMul(other); err != nil {
// the magic happens here
if retVal, err = doT.MatMul(doOther); err != nil {
return
}

retShape := types.BorrowInts(len(retShape1) + len(retShape2))
defer types.ReturnInts(retShape)

retShape = retShape[:0]
retShape = append(retShape, retShape1...)
retShape = append(retShape, retShape2...)
Expand All @@ -487,14 +498,5 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
return
}

// now reset everything
types.ReturnAP(t.AP)
t.AP = t.old
t.old = nil

types.ReturnAP(other.AP)
other.AP = other.old
other.old = nil

return
}
8 changes: 6 additions & 2 deletions tensor/f64/arith_linalg_methods_test.go
Expand Up @@ -544,7 +544,6 @@ func TestTouter(t *testing.T) {

}

/*
func TestTensorMul(t *testing.T) {
assert := assert.New(t)
var A, B, C *Tensor
Expand All @@ -561,8 +560,13 @@ func TestTensorMul(t *testing.T) {
expectedData = []float64{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}
assert.Equal(expectedData, C.data)
assert.Equal(expectedShape, C.Shape())

// make sure nothing's changed
assert.Equal(types.Shape{3, 4, 5}, A.Shape())
assert.Equal(types.Shape{4, 3, 2}, B.Shape())
assert.Equal(RangeFloat64(0, 60), A.data)
assert.Equal(RangeFloat64(0, 24), B.data)
}
*/

/*
//TODO
Expand Down
1 change: 1 addition & 0 deletions tensor/f64/perf.go
Expand Up @@ -89,5 +89,6 @@ func ReturnTensor(t *Tensor) {
t.transposeWith = nil
}

t.Unlock()
pool.Put(t)
}
16 changes: 16 additions & 0 deletions tensor/f64/tensor.go
Expand Up @@ -241,6 +241,22 @@ func (t *Tensor) Clone() *Tensor {
return retVal
}

// borrowClone returns a deep copy of t taken from the Tensor pool.
// The clone gets its own AP (the pool tensor's AP is returned first),
// a copy of t.old when t is a transposed view, and a fresh copy of the
// backing data. The clone is locked before being handed back; release
// it with ReturnTensor when done.
func (t *Tensor) borrowClone() *Tensor {
	clone := BorrowTensor(len(t.data))

	// swap the pooled AP for a clone of t's access pattern
	types.ReturnAP(clone.AP)
	clone.AP = t.AP.Clone()

	if t.old != nil {
		clone.old = t.old.Clone()
	}

	// deep-copy the backing data so mutations on the clone
	// never touch the original
	clone.data = append([]float64(nil), t.data...)

	clone.Lock()
	return clone
}

// IsView reports whether t is a view over another Tensor's data
// (i.e. whether it has a non-nil viewOf backing tensor).
func (t *Tensor) IsView() bool {
	viewing := t.viewOf != nil
	return viewing
}
Expand Down

0 comments on commit 87bcc31

Please sign in to comment.