Commit 6435214: WIP

dcu committed Oct 12, 2020
1 parent 030cfb1

Showing 3 changed files with 45 additions and 40 deletions.
2 changes: 2 additions & 0 deletions example_err_test.go
@@ -48,13 +48,15 @@ func Example_errorHandling() {
 			)),
 		)),
 	))
+
+	fmt.Printf("nn2: %v\n", nn2)
 
 	defer func() {
 		if r := recover(); r != nil {
 			fmt.Printf("An error occurs (caught by recover()): %v\n", r)
 		}
 	}()
 
 	nn2PlusWrong := Must(Add(nn2, wrong2))
 	_ = nn2PlusWrong
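Aside, not part of the diff: a minimal, self-contained sketch of the Must/recover pattern the example above relies on. The mismatched vector shapes are a hypothetical trigger, assuming Add reports incompatible shapes as an error at graph-construction time; Must turns that error into a panic, and the deferred recover converts the panic back into a printable value.

package main

import (
	"fmt"

	G "gorgonia.org/gorgonia"
)

func main() {
	// The deferred recover catches the panic raised by Must below.
	defer func() {
		if r := recover(); r != nil {
			fmt.Printf("An error occurs (caught by recover()): %v\n", r)
		}
	}()

	g := G.NewGraph()
	a := G.NewVector(g, G.Float64, G.WithName("a"), G.WithShape(3))
	b := G.NewVector(g, G.Float64, G.WithName("b"), G.WithShape(2))

	// Adding vectors of shape (3) and (2) fails; Must panics with that error.
	_ = G.Must(G.Add(a, b))
}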
17 changes: 10 additions & 7 deletions example_operations_test.go
@@ -2,7 +2,6 @@ package gorgonia
 
 import (
 	"fmt"
-	"log"
 	"strings"
 
 	"gorgonia.org/tensor"
@@ -29,12 +28,16 @@ func ExampleSoftMax() {
 	sm := Must(SoftMax(c))
 	m := NewTapeMachine(g)
 	if err := m.RunAll(); err != nil {
-		log.Fatal(err)
+		panic(err)
 	}
 
 	fmt.Printf("a:\n%v\nsoftmax(a) - along last axis (default behaviour):\n%1.2f", a.Value(), sm1.Value())
 	fmt.Printf("b:\n%v\nsoftmax(b) - along axis 0:\n%1.2f", b.Value(), sm0.Value())
 
 	tmp := fmt.Sprintf("c %v:\n%v\nsoftmax(c) - along last axis (default behaviour) %v:\n%1.2f", c.Value().Shape(), c.Value(), sm.Value().Shape(), sm.Value())
+
 	fmt.Println(strings.Replace(tmp, "\n\n\n", "\n\n", -1))
+
+	// the requirement to use tmp and strings.Replace is because when Go runs example tests, it strips excess newlines.
+
 	// Output:
@@ -76,12 +79,12 @@ func ExampleConcat() {
 
 	z, err := Concat(2, x, y)
 	if err != nil {
-		log.Fatal(err)
+		panic(err)
 	}
 
 	m := NewTapeMachine(g)
 	if err := m.RunAll(); err != nil {
-		log.Fatal(err)
+		panic(err)
 	}
 	tmp := fmt.Sprintf("z %v\n%v", z.Value().Shape(), z.Value())
 	fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1)) // this is because
@@ -155,18 +158,18 @@ func ExampleUnconcat() {
 
 	z, err := Concat(2, x, y)
 	if err != nil {
-		log.Fatal(err)
+		panic(err)
 	}
 
 	unconcats, err := Unconcat(z, 2, 2)
 	if err != nil {
-		log.Fatal(err)
+		panic(err)
 	}
 	a, b := unconcats[0], unconcats[1]
 
 	m := NewTapeMachine(g)
 	if err := m.RunAll(); err != nil {
-		log.Fatal(err)
+		panic(err)
 	}
 	tmp := fmt.Sprintf("a %v\n%v\nb %v\n%v", a.Value().Shape(), a.Value(), b.Value().Shape(), b.Value())
 	fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1))
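Aside, not part of the diff: the comment added in ExampleSoftMax above notes that Go strips excess newlines when it runs example tests, so consecutive blank lines in an // Output: block are unreliable. A minimal sketch of the strings.Replace workaround, placed in a _test.go file, collapsing a triple newline into a double one that the output block can match:

package gorgonia_test

import (
	"fmt"
	"strings"
)

// Example_collapseNewlines collapses "\n\n\n" into "\n\n" so the printed
// text lines up with the // Output: block below (a single blank line).
func Example_collapseNewlines() {
	tmp := "a\n\n\nb"
	fmt.Println(strings.Replace(tmp, "\n\n\n", "\n\n", -1))

	// Output:
	// a
	//
	// b
}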
66 changes: 33 additions & 33 deletions op_softmax.go
@@ -3,6 +3,7 @@ package gorgonia
 import (
 	"fmt"
 	"hash"
+	"os"
 
 	"github.com/chewxy/hm"
 	"github.com/pkg/errors"
@@ -78,10 +79,6 @@ func (op *softmaxOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
 		return nil, errors.Errorf("Expected input to be a tensor")
 	}
 
-	if in.Shape().Dims() != 1 {
-		return nil, errors.Errorf("Expected input to have 1 dimensions")
-	}
-
 	return in, nil
 }
 
@@ -116,12 +113,40 @@ func (op *softmaxOp) Do(inputs ...Value) (retVal Value, err error) {
 		return nil, fmt.Errorf("error calculating sum for SoftMax: %w", err)
 	}
 
-	div, err := tensor.Div(exp, sum)
+	ss := sum.Shape()
+	dimsDiff := exp.Shape().Dims() - ss.Dims()
+	if dimsDiff == 0 {
+		div, err := tensor.Div(exp, sum)
+		if err != nil {
+			return nil, fmt.Errorf("error calculating div for SoftMax: %w", err)
+		}
+
+		return div, nil
+	}
+
+	fmt.Fprintf(os.Stderr, "initial sum: %v axis=%d expShape=%v expDims=%d\nDIFF: %d\n", sum, axis, exp.Shape(), exp.Dims(), dimsDiff)
+
+	newShape := tensor.Shape(tensor.BorrowInts(ss.Dims() + dimsDiff))
+	copy(newShape, ss)
+	copy(newShape[axis+1:], newShape[axis:])
+	newShape[axis] = 1
+
+	fmt.Fprintf(os.Stderr, "new shape: %v\n", newShape)
+
+	if err = sum.Reshape(newShape...); err != nil {
+		return nil, fmt.Errorf("error reshaping sum for SoftMax: %w", err)
+	}
+
+	fmt.Fprintf(os.Stderr, "sum reshaped: \n%v\nshape: %v\n", sum, sum.Shape())
+
+	sum, err = tensor.Repeat(sum, axis, exp.Shape()[1:]...)
 	if err != nil {
-		return nil, fmt.Errorf("error calculating div for SoftMax: %w", err)
+		return nil, fmt.Errorf("error repeating sum for SoftMax: %w", err)
 	}
 
-	return div, nil
+	fmt.Fprintf(os.Stderr, "sum repeated: \n%v\nshape: %v\nexp=\n%v\n", sum, sum.Shape(), exp)
+
+	return tensor.Div(exp, sum)
 }
 
 // DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface.
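Aside, not part of the diff: a minimal sketch of the reshape-then-repeat idea in the new Do code, using gorgonia.org/tensor directly. Summing along an axis drops that dimension, so the sum gets a size-1 axis re-inserted at the reduced position and is then repeated until its shape matches the numerator's, after which the elementwise division is well defined; the shapes and values below are arbitrary.

package main

import (
	"fmt"

	"gorgonia.org/tensor"
)

func main() {
	axis := 1
	exp := tensor.New(
		tensor.WithShape(2, 3),
		tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6}),
	)

	// Summing along axis 1 drops that dimension: (2, 3) -> (2).
	sum, err := tensor.Sum(exp, axis)
	if err != nil {
		panic(err)
	}

	// Re-insert the reduced axis with size 1: (2) -> (2, 1).
	if err := sum.Reshape(2, 1); err != nil {
		panic(err)
	}

	// Repeat along that axis so the shapes match: (2, 1) -> (2, 3).
	repeated, err := tensor.Repeat(sum, axis, 3)
	if err != nil {
		panic(err)
	}

	// Divide elementwise; each row of the result sums to 1.
	div, err := tensor.Div(exp, repeated)
	if err != nil {
		panic(err)
	}
	fmt.Println(div)
}

Note that the diff's own call, tensor.Repeat(sum, axis, exp.Shape()[1:]...), takes its repeat counts from exp.Shape()[1:], which lines up with this sketch only for an axis-1 reduction of a rank-2 input; together with the os.Stderr debug prints, this looks like the part still being worked out in this WIP commit.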
@@ -245,32 +270,7 @@ func (op *softmaxDiffOp) Do(inputs ...Value) (Value, error) {
 		return nil, fmt.Errorf("Can't check SoftmaxDiff input: %w", err)
 	}
 
-	diag := tensor.New(tensor.AsDenseDiag(inputTensor))
-	sm := inputTensor.Clone().(tensor.Tensor)
-
-	err = sm.Reshape(inputTensor.Shape().TotalSize(), 1)
-	if err != nil {
-		return nil, fmt.Errorf("softmaxDiffOp.Do error reshaping the value: %w", err)
-	}
-
-	smT := sm.Clone().(tensor.Tensor)
-
-	err = smT.T()
-	if err != nil {
-		return nil, fmt.Errorf("softmaxDiffOp.Do error transposing the value: %w", err)
-	}
-
-	smDot, err := tensor.MatMul(sm, smT)
-	if err != nil {
-		return nil, fmt.Errorf("softmaxDiffOp.Do error calculating dot product: %w", err)
-	}
-
-	result, err := tensor.Sub(diag, smDot)
-	if err != nil {
-		return nil, fmt.Errorf("softmaxDiffOp.Do error calculating sub: %w", err)
-	}
-
-	return result, nil
+	return inputTensor, nil
 }
 
 // ensure it complies with the Op interface
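Aside, not part of the diff: the block removed from softmaxDiffOp.Do above computed the dense softmax Jacobian. For s = softmax(x), the deleted code built diag(s) via tensor.AsDenseDiag and the outer product s·sᵀ via tensor.MatMul of the column vector with its transpose, i.e.

	∂s_i/∂x_j = s_i(δ_ij − s_j), or in matrix form J = diag(s) − s sᵀ.

Returning inputTensor instead is presumably a WIP placeholder, so in this commit the diff produced by this op is not the actual softmax gradient.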
