Skip to content

Commit

Permalink
Dropout fix 2 (#399)
Browse files Browse the repository at this point in the history
* Fix Dropout() to return corrrect scaled output

* Use keep prob instead of dropout prob to scale outputs in Dropout

Co-authored-by: Chewxy <chewxy@gmail.com>
  • Loading branch information
MarkKremer and chewxy committed Apr 21, 2020
1 parent 48f01f1 commit bc51c8e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
15 changes: 8 additions & 7 deletions nn.go
Expand Up @@ -62,16 +62,17 @@ func BinaryXent(output, target *Node) (retVal *Node, err error) {
// Dropout is a convenience function to implement dropout.
// It uses randomly zeroes out a *Tensor with a probability drawn from
// a uniform distribution
func Dropout(x *Node, prob float64) (retVal *Node, err error) {
return dropout(x, prob, UniformRandomNode)
func Dropout(x *Node, dropProb float64) (retVal *Node, err error) {
return dropout(x, dropProb, UniformRandomNode)
}

type dropoutRandFn func(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node

func dropout(x *Node, prob float64, randFn dropoutRandFn) (retVal *Node, err error) {
if prob == 0.0 {
func dropout(x *Node, dropProb float64, randFn dropoutRandFn) (retVal *Node, err error) {
if dropProb == 0.0 {
return x, nil
}
keepProb := 1.0 - dropProb

var dt tensor.Dtype
if dt, err = dtypeOf(x.t); err != nil {
Expand All @@ -81,17 +82,17 @@ func dropout(x *Node, prob float64, randFn dropoutRandFn) (retVal *Node, err err
var pr Value
switch dt {
case Float64:
pr, _ = anyToScalar(prob)
pr, _ = anyToScalar(keepProb)
case Float32:
pr, _ = anyToScalar(float32(prob))
pr, _ = anyToScalar(float32(keepProb))
default:
return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
}

p := NewConstant(pr)

m := randFn(x.g, dt, 0, 1, x.shape...)
if retVal, err = Gt(m, p, true); err != nil {
if retVal, err = Lt(m, p, true); err != nil {
return nil, errors.Wrap(err, "Greater Than failed")
}

Expand Down
9 changes: 4 additions & 5 deletions nn_test.go
Expand Up @@ -19,11 +19,10 @@ func TestDropout(t *testing.T) {
expected interface{}
}{
{Float64, 0.0, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{1.0, 1.0, 1.0, 1.0, 1.0}},
{Float64, 0.2, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 5.0, 5.0, 5.0}},
{Float64, 0.5, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 0.0, 2.0, 2.0}},
{Float64, 1.0, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{0.0, 0.0, 0.0, 0.0, 0.0}},
{Float32, 0.2, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{0.0, 0.0, 5.0, 5.0, 5.0}},
{Float32, 0.5, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{0.0, 0.0, 0.0, 2.0, 2.0}},
{Float64, 0.2, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{1.25, 1.25, 1.25, 0.0, 0.0}},
{Float64, 0.5, []float64{0.0, 0.2, 0.5, 0.8, 1.0}, []float64{2.0, 2.0, 0.0, 0.0, 0.0}},
{Float32, 0.2, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{1.25, 1.25, 1.25, 0.0, 0.0}},
{Float32, 0.5, []float32{0.0, 0.2, 0.5, 0.8, 1.0}, []float32{2.0, 2.0, 0.0, 0.0, 0.0}},
}

for _, tt := range tests {
Expand Down

0 comments on commit bc51c8e

Please sign in to comment.