Skip to content

Commit

Permalink
Merge pull request #440 from dcu/softmax-op
Browse files Browse the repository at this point in the history
Softmax op
  • Loading branch information
chewxy committed Oct 14, 2020
2 parents aafd21c + 7cd01f4 commit b44c85d
Show file tree
Hide file tree
Showing 9 changed files with 473 additions and 34 deletions.
7 changes: 6 additions & 1 deletion complex_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package gorgonia

import "testing"
import (
"runtime/debug"
"testing"
)

func TestWeirdNetwork(t *testing.T) {
const (
Expand Down Expand Up @@ -138,6 +141,8 @@ func TestWeirdNetwork(t *testing.T) {
for i := 0; i < 2; i++ {
if err = m.RunAll(); err != nil {
t.Errorf("%d %v", i, err)
t.Log(string(debug.Stack()))

break
}

Expand Down
6 changes: 4 additions & 2 deletions example_err_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,22 @@ func Example_errorHandling() {
)),
)),
))

fmt.Printf("nn2: %v\n", nn2)

defer func() {
if r := recover(); r != nil {
fmt.Printf("An error occurs (caught by recover()): %v\n", r)
}
}()

nn2PlusWrong := Must(Add(nn2, wrong2))
_ = nn2PlusWrong

// Output:
// nn: ÷ false(%a, %f) :: Matrix float32
// nn: Softmax{-1}()(%9) :: Matrix float32
// An error occurs: Type inference error. Op: + false. Children: [Matrix float32, Matrix float64], OpType:Matrix a → Matrix a → Matrix a: Unable to unify while inferring type of + false: Unification Fail: float64 ~ float32 cannot be unified
// nn2: ÷ false(%a, %f) :: Matrix float32
// nn2: Softmax{-1}()(%9) :: Matrix float32
// An error occurs (caught by recover()): Type inference error. Op: + false. Children: [Matrix float32, Matrix float64], OpType:Matrix a → Matrix a → Matrix a: Unable to unify while inferring type of + false: Unification Fail: float64 ~ float32 cannot be unified

}
17 changes: 10 additions & 7 deletions example_operations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gorgonia

import (
"fmt"
"log"
"strings"

"gorgonia.org/tensor"
Expand All @@ -29,12 +28,16 @@ func ExampleSoftMax() {
sm := Must(SoftMax(c))
m := NewTapeMachine(g)
if err := m.RunAll(); err != nil {
log.Fatal(err)
panic(err)
}

fmt.Printf("a:\n%v\nsoftmax(a) - along last axis (default behaviour):\n%1.2f", a.Value(), sm1.Value())
fmt.Printf("b:\n%v\nsoftmax(b) - along axis 0:\n%1.2f", b.Value(), sm0.Value())

tmp := fmt.Sprintf("c %v:\n%v\nsoftmax(c) - along last axis (default behaviour) %v:\n%1.2f", c.Value().Shape(), c.Value(), sm.Value().Shape(), sm.Value())

fmt.Println(strings.Replace(tmp, "\n\n\n", "\n\n", -1))

// the requirement to use tmp and strings.Replace is because when Go runs example tests, it strips excess newlines.

// Output:
Expand Down Expand Up @@ -76,12 +79,12 @@ func ExampleConcat() {

z, err := Concat(2, x, y)
if err != nil {
log.Fatal(err)
panic(err)
}

m := NewTapeMachine(g)
if err := m.RunAll(); err != nil {
log.Fatal(err)
panic(err)
}
tmp := fmt.Sprintf("z %v\n%v", z.Value().Shape(), z.Value())
fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1)) // this is because
Expand Down Expand Up @@ -155,18 +158,18 @@ func ExampleUnconcat() {

z, err := Concat(2, x, y)
if err != nil {
log.Fatal(err)
panic(err)
}

unconcats, err := Unconcat(z, 2, 2)
if err != nil {
log.Fatal(err)
panic(err)
}
a, b := unconcats[0], unconcats[1]

m := NewTapeMachine(g)
if err := m.RunAll(); err != nil {
log.Fatal(err)
panic(err)
}
tmp := fmt.Sprintf("a %v\n%v\nb %v\n%v", a.Value().Shape(), a.Value(), b.Value().Shape(), b.Value())
fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1))
Expand Down
11 changes: 6 additions & 5 deletions known_issues_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package gorgonia

import (
"log"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorgonia.org/tensor"
)

Expand Down Expand Up @@ -303,11 +303,12 @@ func TestIssue363(t *testing.T) {
}

func TestIssue368(t *testing.T) {
c := require.New(t)

g := NewGraph()
x := NewTensor(g, Float32, 2, WithShape(2, 5), WithInit(GlorotU(1.0)))

sm, err := SoftMax(x, 1)
if err != nil {
log.Fatal(err)
}
_ = sm
c.NoError(err)
c.NotNil(sm)
}
10 changes: 9 additions & 1 deletion op_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ func (op linAlgBinOp) do(inputs []Value, opts ...tensor.FuncOpt) (retVal Value,
if err = a.T(); err != nil {
return nil, errors.Wrap(err, tFail)
}

// untranspose
defer a.T()
}
Expand All @@ -740,6 +741,7 @@ func (op linAlgBinOp) do(inputs []Value, opts ...tensor.FuncOpt) (retVal Value,
if err = b.T(); err != nil {
return nil, errors.Wrap(err, tFail)
}

// untranspose
defer b.T()
}
Expand All @@ -751,18 +753,24 @@ func (op linAlgBinOp) do(inputs []Value, opts ...tensor.FuncOpt) (retVal Value,
retVal, err = tensor.MatVecMul(a, b, opts...)
case vecDotOperator:
var ret interface{}

if ret, err = tensor.Inner(a, b); err != nil {
return nil, errors.Wrapf(err, "Failed to carry out linalgBinOp operation %v", op)
}

retVal, _ = anyToScalar(ret)
case outerProdOperator:
retVal, err = tensor.Outer(a, b, opts...)
case batchedMatMulOperator:
// checks were done when the op was created
retVal, err = batchedMatMul(a, b, nil, op.transA, op.transB, false)
}
return

if err != nil {
return nil, fmt.Errorf("linAlgBinOp %v %s %v error: %w", a.Shape(), op.āBinaryOperator, b.Shape(), err)
}

return retVal, nil
}

func (op linAlgBinOp) preallocBatchMatMul(incr bool, prealloc Value, inputs ...Value) (retVal Value, err error) {
Expand Down

0 comments on commit b44c85d

Please sign in to comment.