Commit

Merge cddc4fc into 9fdf854
dcu committed Oct 10, 2020
2 parents 9fdf854 + cddc4fc commit 9a3e65c
Showing 2 changed files with 263 additions and 0 deletions.
224 changes: 224 additions & 0 deletions op_softmax.go
@@ -0,0 +1,224 @@
package gorgonia

import (
    "fmt"
    "hash"

    "github.com/chewxy/hm"
    "github.com/pkg/errors"
    "gorgonia.org/tensor"
)

type softmaxOp struct {
    shape tensor.Shape
    axes  []int
}

func newSoftmaxOp(inputShape tensor.Shape, axes ...int) *softmaxOp {
    softmaxop := &softmaxOp{
        shape: inputShape,
        axes:  axes,
    }

    return softmaxop
}

// SoftMax2 implements the softmax operation.
func SoftMax2(x *Node, axis ...int) (*Node, error) {
    xShape := x.Shape()
    op := newSoftmaxOp(xShape, axis...)

    return ApplyOp(op, x)
}
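For context, a minimal usage sketch (not part of this diff) of how SoftMax2 could be wired into a graph and evaluated. It relies on the existing gorgonia graph/VM API (NewGraph, NewVector, Let, NewTapeMachine); the backing values and printed output are illustrative only:

    g := NewGraph()
    x := NewVector(g, tensor.Float64, WithShape(3), WithName("x"))
    sm, err := SoftMax2(x)
    if err != nil {
        // handle error
    }

    Let(x, tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{1, 2, 3})))
    vm := NewTapeMachine(g)
    defer vm.Close()
    if err := vm.RunAll(); err != nil {
        // handle error
    }
    fmt.Println(sm.Value()) // ≈ [0.0900  0.2447  0.6652]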

func (op *softmaxOp) Arity() int {
    return 1
}

func (op *softmaxOp) ReturnsPtr() bool { return false }

func (op *softmaxOp) CallsExtern() bool { return false }

func (op *softmaxOp) WriteHash(h hash.Hash) {
    fmt.Fprintf(h, "Softmax{}()")
}

func (op *softmaxOp) Hashcode() uint32 { return simpleHash(op) }

func (op *softmaxOp) String() string {
    return "Softmax{}()"
}

func (op *softmaxOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
    s := inputs[0].(tensor.Shape).Clone()
    return s, nil
}

func (op *softmaxOp) Type() hm.Type {
    a := hm.TypeVariable('a')
    t := newTensorType(1, a)

    return hm.NewFnType(t, t)
}

func (op *softmaxOp) OverwritesInput() int { return -1 }

func (op *softmaxOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
    if err := checkArity(op, len(inputs)); err != nil {
        return nil, err
    }

    var in tensor.Tensor
    var ok bool

    if in, ok = inputs[0].(tensor.Tensor); !ok {
        return nil, errors.Errorf("Expected input to be a tensor")
    }

    if in.Shape().Dims() != 1 {
        return nil, errors.Errorf("Expected input to have 1 dimension")
    }

    return in, nil
}

func (op *softmaxOp) Do(inputs ...Value) (retVal Value, err error) {
    inputTensor, err := op.checkInput(inputs...)
    if err != nil {
        return nil, fmt.Errorf("Can't check Softmax input: %w", err)
    }

    aShape := inputTensor.Shape()
    axis := aShape.Dims() - 1 // default: last dim
    if aShape.IsColVec() || (aShape.IsVector() && !aShape.IsRowVec()) {
        axis = 0
    }

    if len(op.axes) > 0 {
        if op.axes[0] >= axis+1 || op.axes[0] < 0 {
            return nil, errors.Errorf("Cannot perform SoftMax on axis %d. Input has shape %v", op.axes[0], aShape)
        }

        axis = op.axes[0]
    }

    exp, err := tensor.Exp(inputTensor)
    if err != nil {
        return nil, fmt.Errorf("error calculating exp for SoftMax: %w", err)
    }

    sum, err := tensor.Sum(exp, axis)
    if err != nil {
        return nil, fmt.Errorf("error calculating sum for SoftMax: %w", err)
    }

    div, err := tensor.Div(exp, sum)
    if err != nil {
        return nil, fmt.Errorf("error calculating div for SoftMax: %w", err)
    }

    return div, nil
}
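Do therefore computes softmax(x)_i = exp(x_i) / Σ_j exp(x_j) along the selected axis as a straight exp → sum → div pipeline (note that this unnormalised form can overflow for large inputs; the usual numerically stable variant subtracts the per-axis maximum first, which this op does not do). A small sketch of calling the op directly on a 1-D tensor, with approximate output values:

    op := newSoftmaxOp(tensor.Shape{3})
    in := tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{1, 2, 3}))
    out, err := op.Do(in)
    if err != nil {
        // only expected for non-tensor or non-1-D inputs
    }
    fmt.Println(out.Data()) // ≈ [0.0900  0.2447  0.6652]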

type softmaxDiffOp struct{}

func newSoftmaxOpDiff() *softmaxDiffOp {
    return &softmaxDiffOp{}
}

func (op *softmaxDiffOp) Arity() int {
    return 1
}

func (op *softmaxDiffOp) ReturnsPtr() bool { return false }

func (op *softmaxDiffOp) CallsExtern() bool { return false }

func (op *softmaxDiffOp) WriteHash(h hash.Hash) {
    fmt.Fprintf(h, "SoftmaxDiff{}()")
}

func (op *softmaxDiffOp) Hashcode() uint32 { return simpleHash(op) }

func (op *softmaxDiffOp) String() string {
    return "SoftmaxDiff{}()"
}

func (op *softmaxDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
    s := inputs[0].(tensor.Shape).Clone()

    return s, nil
}

func (op *softmaxDiffOp) Type() hm.Type {
    aType := hm.TypeVariable('a')

    ta := newTensorType(1, aType)

    return hm.NewFnType(ta, ta) // Tensor a → Tensor a
}

func (op *softmaxDiffOp) OverwritesInput() int { return -1 }

func (op *softmaxDiffOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
    if err := checkArity(op, len(inputs)); err != nil {
        return nil, err
    }

    var (
        in tensor.Tensor
        ok bool
    )

    switch t := inputs[0].(type) {
    case *dualValue:
        if in, ok = t.Value.(tensor.Tensor); !ok {
            return nil, errors.Errorf("input should be a tensor, got %T", inputs[0])
        }
    case tensor.Tensor:
        in = t
    default:
        return nil, errors.Errorf("input type is not supported, got %T", inputs[0])
    }

    return in, nil
}

func (op *softmaxDiffOp) Do(inputs ...Value) (Value, error) {
    inputTensor, err := op.checkInput(inputs...)
    if err != nil {
        return nil, fmt.Errorf("Can't check SoftmaxDiff input: %w", err)
    }

    diag, err := tensor.Diag(inputTensor)
    if err != nil {
        return nil, fmt.Errorf("softmaxDiffOp.Do error calculating diag: %w", err)
    }

    sm := inputTensor.Clone().(tensor.Tensor)
    sm.Reshape(inputTensor.Shape().TotalSize(), 1)

    smT := sm.Clone().(tensor.Tensor)
    smT.Transpose()

    smDot, err := tensor.Dot(sm, smT)
    if err != nil {
        return nil, fmt.Errorf("softmaxDiffOp.Do error calculating dot product: %w", err)
    }

    result, err := tensor.Sub(diag, smDot)
    if err != nil {
        return nil, fmt.Errorf("softmaxDiffOp.Do error calculating sub: %w", err)
    }

    return result, nil
}
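This is the standard softmax Jacobian: for s = softmax(x), ∂s_i/∂x_j = s_i(δ_ij − s_j), i.e. J = diag(s) − s·sᵀ, which is exactly the diag-minus-outer-product computed above. A rough sketch of calling the op directly, reusing the hypothetical softmax output from the forward example:

    op := newSoftmaxOpDiff()
    s := tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{0.0900, 0.2447, 0.6652}))
    jac, err := op.Do(s) // expected: diag(s) - s·sᵀ, an n×n matrix for an n-vector input
    if err != nil {
        // handle error
    }
    fmt.Println(jac)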

// ensure it complies with the Op interface
var (
    _ Op = &softmaxOp{}
    _ Op = &softmaxDiffOp{}
)
39 changes: 39 additions & 0 deletions op_softmax_test.go
@@ -0,0 +1,39 @@
package gorgonia

import (
    "testing"

    "github.com/stretchr/testify/require"
    "gorgonia.org/tensor"
)

var testCasesSoftMaxDo = []struct {
    input    []float64
    expected []float64
}{
    {
        []float64{0.2094, -1.0, 0.6411, 0.0, -0.3909}, []float64{0.2382105379413429, 0.07107636737487558, 0.36681399568548617, 0.19320559786800362, 0.13069350113029174},
    },
    {
        []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{7.801341612780742e-05, 0.00021206245143623275, 0.0005764455082375902, 0.0015669413501390804, 0.004259388198344144, 0.0115782175399118, 0.031472858344688034, 0.08555209892803112, 0.23255471590259755, 0.6321492583604866},
    },
    {
        []float64{0.1, 0.1, 0.1}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333},
    },
    {
        []float64{-0.1, 0.3, -1.1, 2.7}, []float64{0.05180179352659075, 0.07727919496508177, 0.019056814854240642, 0.8518621966540868},
    },
}

func TestSoftmaxDo(t *testing.T) {
    c := require.New(t)

    for i, testCase := range testCasesSoftMaxDo {
        tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
        op := newSoftmaxOp(tt.Shape())

        out, err := op.Do(tt)
        c.NoError(err, "failed test case: %d", i)
        c.Equal(testCase.expected, out.Data(), "failed test case: %d", i)
    }
}
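The expected slices are compared exactly with require.Equal, so they encode the precise float64 results of the exp/sum/div pipeline above. Assuming a standard checkout of the repository, this test alone can be run from the package directory with:

    go test -run TestSoftmaxDo -v .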
