diff --git a/nn.go b/nn.go
index 1f88d088..7f3852a9 100644
--- a/nn.go
+++ b/nn.go
@@ -438,3 +438,9 @@ func BatchNorm(x, scale, bias *Node, momentum, epsilon float64) (retVal, γ, β
 	return retVal, scale, bias, op, err
 }
+
+// GlobalAveragePool2D consumes an input tensor X and applies average pooling across the values in the same channel.
+// The expected input shape is BCHW, where B is the batch size, C is the number of channels, and H and W are the height and width of the data.
+func GlobalAveragePool2D(x *Node) (*Node, error) {
+	return ApplyOp(&globalAveragePoolOp{}, x)
+}
diff --git a/nn_test.go b/nn_test.go
index ce11907f..a486d518 100644
--- a/nn_test.go
+++ b/nn_test.go
@@ -478,3 +478,86 @@ func TestLeakyRelu(t *testing.T) {
 		})
 	}
 }
+
+func TestGlobalAveragePool2D_fwdPass(t *testing.T) {
+	for _, tst := range []struct {
+		inputT         tensor.Tensor
+		expectedOutput tensor.Tensor
+	}{
+		{
+			inputT: tensor.New(
+				tensor.WithShape(1, 3, 5, 5),
+				tensor.WithBacking([]float32{
+					1.7640524, 0.4001572, 0.978738, 2.2408931, 1.867558,
+					-0.9772779, 0.95008844, -0.1513572, -0.10321885, 0.41059852,
+					0.14404356, 1.4542735, 0.7610377, 0.121675014, 0.44386324,
+					0.33367434, 1.4940791, -0.20515826, 0.3130677, -0.85409576,
+					-2.5529897, 0.6536186, 0.8644362, -0.742165, 2.2697546,
+
+					-1.4543657, 0.045758516, -0.18718386, 1.5327792, 1.4693588,
+					0.15494743, 0.37816253, -0.88778573, -1.9807965, -0.34791216,
+					0.15634897, 1.2302907, 1.2023798, -0.3873268, -0.30230275,
+					-1.048553, -1.420018, -1.7062702, 1.9507754, -0.5096522,
+					-0.4380743, -1.2527953, 0.7774904, -1.6138978, -0.21274029,
+
+					-0.89546657, 0.3869025, -0.51080513, -1.1806322, -0.028182229,
+					0.42833188, 0.06651722, 0.3024719, -0.6343221, -0.36274117,
+					-0.67246044, -0.35955316, -0.8131463, -1.7262826, 0.17742614,
+					-0.40178093, -1.6301984, 0.46278226, -0.9072984, 0.051945396,
+					0.7290906, 0.12898292, 1.1394007, -1.2348258, 0.40234163})),
+			expectedOutput: tensor.New(
+				tensor.WithShape(1, 3, 1, 1),
+				tensor.WithBacking([]float32{0.47517386, -0.1940553, -0.28326008})),
+		},
+		{
+			inputT: tensor.New(
+				tensor.WithShape(1, 3, 5, 5),
+				tensor.WithBacking([]float64{
+					1.7640524, 0.4001572, 0.978738, 2.2408931, 1.867558,
+					-0.9772779, 0.95008844, -0.1513572, -0.10321885, 0.41059852,
+					0.14404356, 1.4542735, 0.7610377, 0.121675014, 0.44386324,
+					0.33367434, 1.4940791, -0.20515826, 0.3130677, -0.85409576,
+					-2.5529897, 0.6536186, 0.8644362, -0.742165, 2.2697546,
+
+					-1.4543657, 0.045758516, -0.18718386, 1.5327792, 1.4693588,
+					0.15494743, 0.37816253, -0.88778573, -1.9807965, -0.34791216,
+					0.15634897, 1.2302907, 1.2023798, -0.3873268, -0.30230275,
+					-1.048553, -1.420018, -1.7062702, 1.9507754, -0.5096522,
+					-0.4380743, -1.2527953, 0.7774904, -1.6138978, -0.21274029,
+
+					-0.89546657, 0.3869025, -0.51080513, -1.1806322, -0.028182229,
+					0.42833188, 0.06651722, 0.3024719, -0.6343221, -0.36274117,
+					-0.67246044, -0.35955316, -0.8131463, -1.7262826, 0.17742614,
+					-0.40178093, -1.6301984, 0.46278226, -0.9072984, 0.051945396,
+					0.7290906, 0.12898292, 1.1394007, -1.2348258, 0.40234163})),
+			expectedOutput: tensor.New(
+				tensor.WithShape(1, 3, 1, 1),
+				tensor.WithBacking([]float64{0.47517386, -0.1940553, -0.28326008})),
+		},
+	} {
+		inputT := tst.inputT
+		expectedOutput := tst.expectedOutput
+		g := NewGraph()
+		assert := assert.New(t)
+		x := NodeFromAny(g, inputT)
+		output, err := GlobalAveragePool2D(x)
+
+		if err != nil {
+			t.Fatal(err)
+		}
+		m := NewTapeMachine(g)
+		defer m.Close()
+		if err := m.RunAll(); err != nil {
+			t.Fatalf("%+v", err)
+		}
+		if len(output.Shape()) != len(expectedOutput.Shape()) {
+			t.Fatalf("Bad output shape, expected %v, got %v", expectedOutput.Shape(), output.Shape())
+		}
+		for i, d := range output.Shape() {
+			if expectedOutput.Shape()[i] != d {
+				t.Fatalf("Bad output shape, expected %v, got %v", expectedOutput.Shape(), output.Shape())
+			}
+		}
+		assert.InDeltaSlice(expectedOutput.Data(), output.Value().Data(), 1e-6, "the two tensors should be equal.")
+	}
+}
diff --git a/op_nn.go b/op_nn.go
index 7f3b5a65..d784852a 100644
--- a/op_nn.go
+++ b/op_nn.go
@@ -14,6 +14,7 @@ import (
 	"gorgonia.org/vecf64"
 )
 
+// Sanity checks
 var (
 	_ SDOp = im2colOp{}
 	_ Op   = col2imOp{}
@@ -21,6 +22,7 @@ var (
 	_ Op   = &maxPoolDiffOp{}
 	_ Op   = &BatchNormOp{}
 	_ Op   = &batchnormDiffOp{}
+	_ Op   = &globalAveragePoolOp{}
 )
 
 /*
@@ -1625,3 +1627,119 @@ func (op *batchnormDiffOp) f32s(input, inGrad, outGrad *tensor.Dense) (err error
 
 	return nil
 }
+
+// globalAveragePoolOp implements a 2D global average pooling operation.
+type globalAveragePoolOp struct{}
+
+func (g *globalAveragePoolOp) Arity() int {
+	return 1
+}
+
+func (g *globalAveragePoolOp) Type() hm.Type {
+	a := hm.TypeVariable('a')
+	t := newTensorType(4, a)
+	return hm.NewFnType(t, t)
+}
+
+func (g *globalAveragePoolOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
+	b, err := inputs[0].DimSize(0)
+	if err != nil {
+		return nil, err
+	}
+	c, err := inputs[0].DimSize(1)
+	if err != nil {
+		return nil, err
+	}
+	// check that the input is 4D without doing type inference
+	if _, err := inputs[0].DimSize(2); err != nil {
+		return nil, err
+	}
+	if _, err := inputs[0].DimSize(3); err != nil {
+		return nil, err
+	}
+	return tensor.Shape{b, c, 1, 1}, nil
+}
+
+func (g *globalAveragePoolOp) Do(inputs ...Value) (Value, error) {
+	im := inputs[0]
+	switch v := im.(type) {
+	case tensor.Tensor:
+		B, C, H, W := v.Shape()[0], v.Shape()[1], v.Shape()[2], v.Shape()[3] // batch, channels, height, width
+		s, err := g.InferShape(v.Shape())
+		if err != nil {
+			return nil, err
+		}
+		output := tensor.New(tensor.Of(v.Dtype()), tensor.WithShape(s...))
+		switch v.Dtype() {
+		case tensor.Float64:
+			for b := 0; b < B; b++ {
+				for c := 0; c < C; c++ {
+					var sum float64
+					for h := 0; h < H; h++ {
+						for w := 0; w < W; w++ {
+							val, err := v.At(b, c, h, w)
+							if err != nil {
+								return nil, err
+							}
+							sum += val.(float64)
+						}
+					}
+					err := output.SetAt(sum/float64(H*W), b, c, 0, 0)
+					if err != nil {
+						return nil, err
+					}
+				}
+			}
+		case tensor.Float32:
+			for b := 0; b < B; b++ {
+				for c := 0; c < C; c++ {
+					var sum float32
+					for h := 0; h < H; h++ {
+						for w := 0; w < W; w++ {
+							val, err := v.At(b, c, h, w)
+							if err != nil {
+								return nil, err
+							}
+							sum += val.(float32)
+						}
+					}
+					err := output.SetAt(sum/float32(H*W), b, c, 0, 0)
+					if err != nil {
+						return nil, err
+					}
+				}
+			}
+		default:
+			return nil, nyi("Global Average Pool", v.Dtype())
+		}
+
+		return output, nil
+
+	default:
+		return nil, nyi("globalAveragePoolOp", inputs)
+	}
+}
+
+func (g *globalAveragePoolOp) ReturnsPtr() bool {
+	return false
+}
+
+func (g *globalAveragePoolOp) CallsExtern() bool {
+	return false
+}
+
+func (g *globalAveragePoolOp) OverwritesInput() int {
+	return -1
+}
+
+func (g *globalAveragePoolOp) WriteHash(h hash.Hash) {
+	fmt.Fprintf(h, "GlobalAveragePool")
+}
+
+func (g *globalAveragePoolOp) Hashcode() uint32 {
+	return simpleHash(g)
+}
+
+func (g *globalAveragePoolOp) String() string {
+	return "GlobalAveragePool"
+}
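
For reviewers, a minimal usage sketch of the new API (not part of the diff): the graph plumbing mirrors the test above; the main wrapper, the G import alias, and the tensor.Random backing are illustrative assumptions, while GlobalAveragePool2D, NodeFromAny, NewTapeMachine, and RunAll are the calls exercised by the test.

package main

import (
	"fmt"

	G "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

func main() {
	g := G.NewGraph()

	// One BCHW input: batch=1, 3 channels, 5x5 spatial dimensions, randomly initialised.
	xT := tensor.New(
		tensor.WithShape(1, 3, 5, 5),
		tensor.WithBacking(tensor.Random(tensor.Float64, 1*3*5*5)))
	x := G.NodeFromAny(g, xT)

	// Each 5x5 plane is averaged down to a single value per channel,
	// so the result has shape (1, 3, 1, 1).
	pooled, err := G.GlobalAveragePool2D(x)
	if err != nil {
		panic(err)
	}

	m := G.NewTapeMachine(g)
	defer m.Close()
	if err := m.RunAll(); err != nil {
		panic(err)
	}
	fmt.Println(pooled.Value())
}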