
Commit

Fix Dropout() to return correct scaled output (#388)
MarkKremer committed Mar 25, 2020
1 parent f5fab5f commit 3dc3784
Showing 2 changed files with 48 additions and 7 deletions.
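
Reading the diff below: the old implementation built a constant c from opp = 1/prob and returned HadamardDiv(retVal, c), which divides the masked activations by 1/prob, i.e. multiplies the survivors by prob and shrinks them. The fix divides by the probability constant p directly, so survivors are scaled by 1/prob, matching the expectations in the new table-driven test. The commit also threads the random source through a dropoutRandFn parameter so the test can inject a deterministic vector in place of UniformRandomNode.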
nn.go: 15 changes (9 additions, 6 deletions)
```diff
@@ -63,6 +63,12 @@ func BinaryXent(output, target *Node) (retVal *Node, err error) {
 // It uses randomly zeroes out a *Tensor with a probability drawn from
 // a uniform distribution
 func Dropout(x *Node, prob float64) (retVal *Node, err error) {
+	return dropout(x, prob, UniformRandomNode)
+}
+
+type dropoutRandFn func(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node
+
+func dropout(x *Node, prob float64, randFn dropoutRandFn) (retVal *Node, err error) {
 	if prob == 0.0 {
 		return x, nil
 	}
@@ -72,22 +78,19 @@ func Dropout(x *Node, prob float64) (retVal *Node, err error) {
 		return nil, errors.Wrap(err, dtypeOfFail)
 	}
 
-	var opp, pr Value // opp = 1 per p
+	var pr Value
 	switch dt {
 	case Float64:
-		opp, _ = anyToScalar(1.0 / prob)
 		pr, _ = anyToScalar(prob)
 	case Float32:
-		opp, _ = anyToScalar(float32(1.0 / prob))
 		pr, _ = anyToScalar(float32(prob))
 	default:
 		return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
 	}
 
 	p := NewConstant(pr)
-	c := NewConstant(opp)
 
-	m := UniformRandomNode(x.g, dt, 0, 1, x.shape...)
+	m := randFn(x.g, dt, 0, 1, x.shape...)
 	if retVal, err = Gt(m, p, true); err != nil {
 		return nil, errors.Wrap(err, "Greater Than failed")
 	}
@@ -96,7 +99,7 @@ func Dropout(x *Node, prob float64) (retVal *Node, err error) {
 		return nil, errors.Wrap(err, mulFail)
 	}
 
-	return HadamardDiv(retVal, c)
+	return HadamardDiv(retVal, p)
 }
 
 // LeakyRelu returns a node whose underlying value is:
```
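To make the rescaling concrete, here is a minimal plain-Go sketch (my illustration, not code from the commit) of the mask-then-divide sequence performed in the graph above; dropoutSketch and the fixed rnd slice are hypothetical stand-ins for the graph ops and UniformRandomNode:

```go
package main

import "fmt"

// dropoutSketch mirrors dropout(): zero x[i] where rnd[i] <= prob
// (the Gt(m, p, true) mask), then divide survivors by prob
// (the HadamardDiv(retVal, p) step introduced by this commit).
func dropoutSketch(x, rnd []float64, prob float64) []float64 {
	out := make([]float64, len(x))
	if prob == 0.0 { // same early return as dropout()
		copy(out, x)
		return out
	}
	for i := range x {
		if rnd[i] > prob {
			out[i] = x[i] / prob
		}
	}
	return out
}

func main() {
	x := []float64{1, 1, 1, 1, 1}
	rnd := []float64{0.0, 0.2, 0.5, 0.8, 1.0} // the test's injected "random" draws
	fmt.Println(dropoutSketch(x, rnd, 0.2))   // [0 0 5 5 5], as TestDropout expects
}
```

With prob = 0.2 and those draws, only the elements whose draw exceeds 0.2 survive, and each comes out as 1.0/0.2 = 5.0, exactly the expected row in nn_test.go below.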
nn_test.go: 40 changes (39 additions, 1 deletion)
```diff
@@ -1,6 +1,7 @@
 package gorgonia
 
 import (
+	"fmt"
 	"io/ioutil"
 	"runtime"
 	"testing"
@@ -10,6 +11,43 @@ import (
 	"gorgonia.org/tensor"
 )
 
+func TestDropout(t *testing.T) {
+	var tests = []struct {
+		dt       tensor.Dtype
+		prob     float64
+		rand     interface{}
+		expected interface{}
+	}{
+		{Float64, 0.0, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{1.0, 1.0, 1.0, 1.0, 1.0}},
+		{Float64, 0.2, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 5.0, 5.0, 5.0}},
+		{Float64, 0.5, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 0.0, 2.0, 2.0}},
+		{Float64, 1.0, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 0.0, 0.0, 0.0}},
+		{Float32, 0.2, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{0.0, 0.0, 5.0, 5.0, 5.0}},
+		{Float32, 0.5, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{0.0, 0.0, 0.0, 2.0, 2.0}},
+	}
+
+	for _, tt := range tests {
+		name := fmt.Sprintf("%v-%.1f", tt.dt, tt.prob)
+		t.Run(name, func(t *testing.T) {
+			randFn := func(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node {
+				return NewVector(g, dt, WithShape(shape...), WithInit(func(dt tensor.Dtype, s ...int) interface{} {
+					return tt.rand
+				}))
+			}
+			g := NewGraph()
+			x := NewVector(g, tt.dt, WithShape(5), WithName("x"), WithInit(Ones()))
+			do := Must(dropout(x, tt.prob, randFn))
+
+			m := NewTapeMachine(g, BindDualValues())
+			defer m.Close()
+			defer runtime.GC()
+
+			assert.NoError(t, m.RunAll())
+			assert.Equal(t, tt.expected, do.Value().Data())
+		})
+	}
+}
+
 func dropoutTest(t *testing.T, dt tensor.Dtype) error {
 	g := NewGraph()
 	x := NewVector(g, dt, WithShape(10), WithName("x"), WithInit(RangedFrom(0)))
@@ -44,7 +82,7 @@ func dropoutTest(t *testing.T, dt tensor.Dtype) error {
 	return nil
 }
 
-func TestDropout(t *testing.T) {
+func TestDropout_integration(t *testing.T) {
 	// t.Skip()
 
 	if err := dropoutTest(t, Float64); err != nil {
```
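And a hedged end-to-end sketch of the public API as seen from outside the package (again my illustration, assuming the conventional import alias G for gorgonia.org/gorgonia; the calls mirror the tests above, but here the mask really is random):

```go
package main

import (
	"fmt"

	G "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

func main() {
	g := G.NewGraph()
	x := G.NewVector(g, tensor.Float64, G.WithShape(5), G.WithName("x"), G.WithInit(G.Ones()))
	// After this commit, surviving elements are divided by prob: 1.0/0.5 = 2.
	do := G.Must(G.Dropout(x, 0.5))

	m := G.NewTapeMachine(g)
	defer m.Close()
	if err := m.RunAll(); err != nil {
		panic(err)
	}
	fmt.Println(do.Value()) // e.g. [0 2 2 0 2]; the mask is genuinely random here
}
```

Passing the RNG in as dropoutRandFn rather than hard-coding UniformRandomNode is what makes the unit test deterministic while leaving the public Dropout signature untouched.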
