diff --git a/tensor/f64/arith_linalg_methods.go b/tensor/f64/arith_linalg_methods.go
index 8f6812f0..97447319 100644
--- a/tensor/f64/arith_linalg_methods.go
+++ b/tensor/f64/arith_linalg_methods.go
@@ -1,8 +1,6 @@
 package tensorf64
 
 import (
-	"log"
-
 	"github.com/chewxy/gorgonia/tensor/types"
 	"github.com/gonum/blas"
 	"github.com/pkg/errors"
@@ -355,8 +353,14 @@ func (t *Tensor) outer(other, retVal *Tensor) {
 	return
 }
 
-// TensorMul is for multiplying Tensors with more than 2 dimensions. It exploits a trick in D/SGEMM, and reshaping using the fortran order
-// BROKEN. DO NOT USE UNTIL FIXED.
+// TensorMul is for multiplying Tensors with more than 2 dimensions.
+//
+// The algorithm is conceptually simple (but tricky to get right):
+//	1. Transpose and reshape the Tensors in such a way that both t and other are 2D matrices
+//	2. Use DGEMM to multiply them
+//	3. Reshape the results to be the new expected result
+//
+// This function is a Go implementation of Numpy's tensordot method. It simplifies a lot of what Numpy does.
 func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, err error) {
 	ts := t.Shape()
 	td := len(ts)
@@ -382,6 +386,7 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
 			}
 		}
 	}
+
 	if !sameLength {
 		err = shapeMismatchError(ts, os)
 		return
@@ -418,11 +423,9 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
 	retShape1 := types.BorrowInts(len(ts))
 	defer types.ReturnInts(retShape1)
 	retShape1 = retShape1[:0]
-	log.Printf("len(retShape %d", len(retShape1))
 	for _, ni := range notins {
 		retShape1 = append(retShape1, ts[ni])
 	}
-	log.Printf("NewAxesA: %v | notins: %v | newShapeT: %v | retShape:%v", newAxesA, notins, newShapeT, retShape1)
 
 	// work on other now
 	notins = notins[:0]
@@ -438,10 +441,11 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
 			notins = append(notins, i)
 		}
 	}
+
 	newAxesB := types.BorrowInts(len(notins) + len(axesB))
 	defer types.ReturnInts(newAxesB)
 	newAxesB = newAxesB[:0]
-	newAxesB = append(notins, axesB...)
+	newAxesB = append(axesB, notins...)
 
 	newShapeO := types.Shape(types.BorrowInts(2))
 	defer types.ReturnInts(newShapeO)
@@ -453,32 +457,39 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
 	for _, ni := range notins {
 		retShape2 = append(retShape2, os[ni])
 	}
 
-	log.Printf("NewAxesB: %v | notins: %v | newShapeO: %v | retShape:%v", newAxesB, notins, newShapeO, retShape2)
 
-	if err = t.T(newAxesA...); err != nil {
+	// we borrowClone because we don't want to touch the original Tensors
+	doT := t.borrowClone()
+	doOther := other.borrowClone()
+	defer ReturnTensor(doT)
+	defer ReturnTensor(doOther)
+
+	if err = doT.T(newAxesA...); err != nil {
 		return
 	}
-	log.Printf("T.strides: %v", t.Strides())
-	log.Printf("%v", t)
+	doT.Transpose() // we have to materialize the transpose first or the underlying data won't be changed and the reshape that follows would be meaningless
 
-	if err = t.Reshape(newShapeT...); err != nil {
+	if err = doT.Reshape(newShapeT...); err != nil {
 		return
 	}
 
-	if err = other.T(newAxesB...); err != nil {
+	if err = doOther.T(newAxesB...); err != nil {
 		return
 	}
+	doOther.Transpose()
 
-	if err = other.Reshape(newShapeO...); err != nil {
+	if err = doOther.Reshape(newShapeO...); err != nil {
 		return
 	}
 
-	if retVal, err = t.MatMul(other); err != nil {
+	// the magic happens here
+	if retVal, err = doT.MatMul(doOther); err != nil {
 		return
 	}
 
 	retShape := types.BorrowInts(len(retShape1) + len(retShape2))
 	defer types.ReturnInts(retShape)
+	retShape = retShape[:0]
 	retShape = append(retShape, retShape1...)
 	retShape = append(retShape, retShape2...)
@@ -487,14 +498,5 @@ func (t *Tensor) TensorMul(other *Tensor, axesA, axesB []int) (retVal *Tensor, e
 		return
 	}
 
-	// now reset everything
-	types.ReturnAP(t.AP)
-	t.AP = t.old
-	t.old = nil
-
-	types.ReturnAP(other.AP)
-	other.AP = other.old
-	other.old = nil
-
 	return
 }
diff --git a/tensor/f64/arith_linalg_methods_test.go b/tensor/f64/arith_linalg_methods_test.go
index a4d20210..21d2d890 100644
--- a/tensor/f64/arith_linalg_methods_test.go
+++ b/tensor/f64/arith_linalg_methods_test.go
@@ -544,7 +544,6 @@ func TestTouter(t *testing.T) {
 
 }
 
-/*
 func TestTensorMul(t *testing.T) {
 	assert := assert.New(t)
 	var A, B, C *Tensor
@@ -561,8 +560,13 @@ func TestTensorMul(t *testing.T) {
 	expectedData = []float64{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}
 	assert.Equal(expectedData, C.data)
 	assert.Equal(expectedShape, C.Shape())
+
+	// make sure nothing's changed
+	assert.Equal(types.Shape{3, 4, 5}, A.Shape())
+	assert.Equal(types.Shape{4, 3, 2}, B.Shape())
+	assert.Equal(RangeFloat64(0, 60), A.data)
+	assert.Equal(RangeFloat64(0, 24), B.data)
 }
-*/
 
 /*
 //TODO
diff --git a/tensor/f64/perf.go b/tensor/f64/perf.go
index a05118f8..f70c4ed6 100644
--- a/tensor/f64/perf.go
+++ b/tensor/f64/perf.go
@@ -89,5 +89,6 @@ func ReturnTensor(t *Tensor) {
 		t.transposeWith = nil
 	}
 
+	t.Unlock()
 	pool.Put(t)
 }
diff --git a/tensor/f64/tensor.go b/tensor/f64/tensor.go
index b1358b77..5e076ac3 100644
--- a/tensor/f64/tensor.go
+++ b/tensor/f64/tensor.go
@@ -241,6 +241,22 @@ func (t *Tensor) Clone() *Tensor {
 	return retVal
 }
 
+func (t *Tensor) borrowClone() *Tensor {
+	retVal := BorrowTensor(len(t.data))
+	types.ReturnAP(retVal.AP)
+	retVal.AP = t.AP.Clone()
+
+	if t.old != nil {
+		retVal.old = t.old.Clone()
+	}
+
+	newdata := make([]float64, len(t.data))
+	copy(newdata, t.data)
+	retVal.data = newdata
+	retVal.Lock()
+	return retVal
+}
+
 func (t *Tensor) IsView() bool {
 	return t.viewOf != nil
 }
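
For reference, the re-enabled TestTensorMul corresponds to the canonical numpy.tensordot contraction of a 3x4x5 tensor with a 4x3x2 tensor. Below is a minimal usage sketch of the repaired TensorMul; the NewTensor/WithShape/WithBacking constructors and the axes arguments ([]int{1, 0} and []int{0, 1}) are assumptions inferred from the shapes and expected data asserted in the test above, not part of this patch.

    package main

    import (
    	"fmt"

    	tf64 "github.com/chewxy/gorgonia/tensor/f64"
    )

    func main() {
    	// A is 3x4x5, B is 4x3x2, matching the shapes the restored test checks afterwards.
    	// The constructor options here are assumed; RangeFloat64 is the helper used in the test.
    	A := tf64.NewTensor(tf64.WithShape(3, 4, 5), tf64.WithBacking(tf64.RangeFloat64(0, 60)))
    	B := tf64.NewTensor(tf64.WithShape(4, 3, 2), tf64.WithBacking(tf64.RangeFloat64(0, 24)))

    	// Contract A's axes (1, 0) against B's axes (0, 1). With this patch, TensorMul
    	// works on borrowed clones internally, so A and B keep their shape and data.
    	C, err := A.TensorMul(B, []int{1, 0}, []int{0, 1})
    	if err != nil {
    		fmt.Println(err)
    		return
    	}
    	fmt.Println(C.Shape()) // expected (5, 2), with data starting 4400, 4730, ...
    }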