Fix Dropout() to return correctly scaled output (#388)

MarkKremer committed Mar 25, 2020
1 parent f5fab5f commit 3dc37840f84e479da7acf4f6618f39e6755a4bdd
Showing with 48 additions and 7 deletions.
  1. +9 −6 nn.go
  2. +39 −1 nn_test.go
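
The arithmetic of the bug, in one sentence: the old code built a constant c = 1/prob and returned HadamardDiv(retVal, c), so values that survived the mask were divided by 1/prob, that is, multiplied by prob, shrinking them instead of compensating for the dropped entries; the fix divides by prob directly. A minimal standalone sketch of the corrected behaviour, using plain Go slices rather than the Gorgonia graph API (dropoutScale and its names are illustrative only):

package main

import "fmt"

// dropoutScale mirrors the fixed graph ops: zero entries whose random
// draw is <= prob and divide the survivors by prob.
func dropoutScale(x, draws []float64, prob float64) []float64 {
	out := make([]float64, len(x))
	for i := range x {
		if draws[i] > prob { // survives the mask, cf. Gt(m, p, true)
			out[i] = x[i] / prob // the fix; the old code divided by 1/prob, i.e. multiplied by prob
		}
	}
	return out
}

func main() {
	x := []float64{1, 1, 1, 1, 1}
	draws := []float64{0.0, 0.2, 0.5, 0.8, 1.0}
	fmt.Println(dropoutScale(x, draws, 0.5)) // prints [0 0 0 2 2], matching the new test
}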
nn.go (+9 −6)
@@ -63,6 +63,12 @@ func BinaryXent(output, target *Node) (retVal *Node, err error) {
 // It uses randomly zeroes out a *Tensor with a probability drawn from
 // a uniform distribution
 func Dropout(x *Node, prob float64) (retVal *Node, err error) {
+	return dropout(x, prob, UniformRandomNode)
+}
+
+type dropoutRandFn func(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node
+
+func dropout(x *Node, prob float64, randFn dropoutRandFn) (retVal *Node, err error) {
 	if prob == 0.0 {
 		return x, nil
 	}
@@ -72,22 +78,19 @@ func Dropout(x *Node, prob float64) (retVal *Node, err error) {
 		return nil, errors.Wrap(err, dtypeOfFail)
 	}
 
-	var opp, pr Value // opp = 1 per p
+	var pr Value
 	switch dt {
 	case Float64:
-		opp, _ = anyToScalar(1.0 / prob)
 		pr, _ = anyToScalar(prob)
 	case Float32:
-		opp, _ = anyToScalar(float32(1.0 / prob))
 		pr, _ = anyToScalar(float32(prob))
 	default:
 		return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
 	}
 
 	p := NewConstant(pr)
-	c := NewConstant(opp)
 
-	m := UniformRandomNode(x.g, dt, 0, 1, x.shape...)
+	m := randFn(x.g, dt, 0, 1, x.shape...)
 	if retVal, err = Gt(m, p, true); err != nil {
 		return nil, errors.Wrap(err, "Greater Than failed")
 	}
@@ -96,7 +99,7 @@ func Dropout(x *Node, prob float64) (retVal *Node, err error) {
 		return nil, errors.Wrap(err, mulFail)
 	}
 
-	return HadamardDiv(retVal, c)
+	return HadamardDiv(retVal, p)
 }
 
 // LeakyRelu returns a node whose underlying value is:
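
Callers are unaffected: Dropout keeps its public signature and now delegates to the unexported dropout with UniformRandomNode as the random source, which is what makes the deterministic test below possible. A brief usage sketch, assuming a typical Gorgonia graph setup (the 0.5 rate and variable names are arbitrary):

g := NewGraph()
x := NewVector(g, Float64, WithShape(10), WithName("x"), WithInit(RangedFrom(0)))
d, err := Dropout(x, 0.5) // internally: dropout(x, 0.5, UniformRandomNode)
if err != nil {
	// handle the error
}
_ = d // use d like any other *Node in the graph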
nn_test.go (+39 −1)
@@ -1,6 +1,7 @@
 package gorgonia
 
 import (
+	"fmt"
 	"io/ioutil"
 	"runtime"
 	"testing"
@@ -10,6 +11,43 @@ import (
 	"gorgonia.org/tensor"
 )
 
+func TestDropout(t *testing.T) {
+	var tests = []struct {
+		dt       tensor.Dtype
+		prob     float64
+		rand     interface{}
+		expected interface{}
+	}{
+		{Float64, 0.0, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{1.0, 1.0, 1.0, 1.0, 1.0}},
+		{Float64, 0.2, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 5.0, 5.0, 5.0}},
+		{Float64, 0.5, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 0.0, 2.0, 2.0}},
+		{Float64, 1.0, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 0.0, 0.0, 0.0}},
+		{Float32, 0.2, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{0.0, 0.0, 5.0, 5.0, 5.0}},
+		{Float32, 0.5, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{0.0, 0.0, 0.0, 2.0, 2.0}},
+	}
+
+	for _, tt := range tests {
+		name := fmt.Sprintf("%v-%.1f", tt.dt, tt.prob)
+		t.Run(name, func(t *testing.T) {
+			randFn := func(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node {
+				return NewVector(g, dt, WithShape(shape...), WithInit(func(dt tensor.Dtype, s ...int) interface{} {
+					return tt.rand
+				}))
+			}
+			g := NewGraph()
+			x := NewVector(g, tt.dt, WithShape(5), WithName("x"), WithInit(Ones()))
+			do := Must(dropout(x, tt.prob, randFn))
+
+			m := NewTapeMachine(g, BindDualValues())
+			defer m.Close()
+			defer runtime.GC()
+
+			assert.NoError(t, m.RunAll())
+			assert.Equal(t, tt.expected, do.Value().Data())
+		})
+	}
+}
+
 func dropoutTest(t *testing.T, dt tensor.Dtype) error {
 	g := NewGraph()
 	x := NewVector(g, dt, WithShape(10), WithName("x"), WithInit(RangedFrom(0)))
@@ -44,7 +82,7 @@ func dropoutTest(t *testing.T, dt tensor.Dtype) error {
 	return nil
 }
 
-func TestDropout(t *testing.T) {
+func TestDropout_integration(t *testing.T) {
 	// t.Skip()
 
 	if err := dropoutTest(t, Float64); err != nil {
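
Two details make the new table test verifiable by hand. Because dropout now takes its random source as a parameter, the test's randFn returns fixed draws instead of UniformRandomNode's output: with prob = 0.2, only draws strictly greater than 0.2 survive (Gt(m, p, true) drops the 0.2 entry itself), and each surviving 1.0 becomes 1.0 / 0.2 = 5.0, exactly the expected rows above. The pre-existing randomized test is renamed TestDropout_integration so the new deterministic TestDropout can take its name.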
