Gap operator (#302)
* feat(wip): scratch space for a Global Average Pooling operator

* chore: skeleton of the operator

* feat: Global Average Pool
owulveryck authored and chewxy committed Nov 7, 2019
1 parent 6cc7466 commit 9ecd7d0
Showing 3 changed files with 207 additions and 0 deletions.
6 changes: 6 additions & 0 deletions nn.go
@@ -438,3 +438,9 @@ func BatchNorm(x, scale, bias *Node, momentum, epsilon float64) (retVal, γ, β

return retVal, scale, bias, op, err
}

// GlobalAveragePool2D consumes an input tensor X and applies average pooling across all the values in each channel,
// so every output element is the mean of one H×W feature map.
// The expected input shape is BCHW, where B is the batch size, C is the number of channels, and H and W are the height and the width of the data; the output shape is (B, C, 1, 1).
func GlobalAveragePool2D(x *Node) (*Node, error) {
return ApplyOp(&globalAveragePoolOp{}, x)
}
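
For orientation, a minimal usage sketch of the new entry point (hedged: the random backing, variable names, and log-based error handling here are illustrative assumptions, not part of this commit):

g := NewGraph()
backing := tensor.Random(tensor.Float64, 1*3*5*5)
x := NodeFromAny(g, tensor.New(tensor.WithShape(1, 3, 5, 5), tensor.WithBacking(backing)))
pooled, err := GlobalAveragePool2D(x) // pooled has shape (1, 3, 1, 1)
if err != nil {
	log.Fatal(err)
}
m := NewTapeMachine(g)
defer m.Close()
if err := m.RunAll(); err != nil {
	log.Fatal(err)
}
fmt.Println(pooled.Value()) // prints one mean per channel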
83 changes: 83 additions & 0 deletions nn_test.go
@@ -478,3 +478,86 @@ func TestLeakyRelu(t *testing.T) {
})
}
}

func TestGlobalAveragePool2D_fwdPass(t *testing.T) {
for _, tst := range []struct {
inputT tensor.Tensor
expectedOutput tensor.Tensor
}{
{
inputT: tensor.New(
tensor.WithShape(1, 3, 5, 5),
tensor.WithBacking([]float32{
1.7640524, 0.4001572, 0.978738, 2.2408931, 1.867558,
-0.9772779, 0.95008844, -0.1513572, -0.10321885, 0.41059852,
0.14404356, 1.4542735, 0.7610377, 0.121675014, 0.44386324,
0.33367434, 1.4940791, -0.20515826, 0.3130677, -0.85409576,
-2.5529897, 0.6536186, 0.8644362, -0.742165, 2.2697546,

-1.4543657, 0.045758516, -0.18718386, 1.5327792, 1.4693588,
0.15494743, 0.37816253, -0.88778573, -1.9807965, -0.34791216,
0.15634897, 1.2302907, 1.2023798, -0.3873268, -0.30230275,
-1.048553, -1.420018, -1.7062702, 1.9507754, -0.5096522,
-0.4380743, -1.2527953, 0.7774904, -1.6138978, -0.21274029,

-0.89546657, 0.3869025, -0.51080513, -1.1806322, -0.028182229,
0.42833188, 0.06651722, 0.3024719, -0.6343221, -0.36274117,
-0.67246044, -0.35955316, -0.8131463, -1.7262826, 0.17742614,
-0.40178093, -1.6301984, 0.46278226, -0.9072984, 0.051945396,
0.7290906, 0.12898292, 1.1394007, -1.2348258, 0.40234163})),
expectedOutput: tensor.New(
tensor.WithShape(1, 3, 1, 1),
tensor.WithBacking([]float32{0.47517386, -0.1940553, -0.28326008})),
},
{
inputT: tensor.New(
tensor.WithShape(1, 3, 5, 5),
tensor.WithBacking([]float64{
1.7640524, 0.4001572, 0.978738, 2.2408931, 1.867558,
-0.9772779, 0.95008844, -0.1513572, -0.10321885, 0.41059852,
0.14404356, 1.4542735, 0.7610377, 0.121675014, 0.44386324,
0.33367434, 1.4940791, -0.20515826, 0.3130677, -0.85409576,
-2.5529897, 0.6536186, 0.8644362, -0.742165, 2.2697546,

-1.4543657, 0.045758516, -0.18718386, 1.5327792, 1.4693588,
0.15494743, 0.37816253, -0.88778573, -1.9807965, -0.34791216,
0.15634897, 1.2302907, 1.2023798, -0.3873268, -0.30230275,
-1.048553, -1.420018, -1.7062702, 1.9507754, -0.5096522,
-0.4380743, -1.2527953, 0.7774904, -1.6138978, -0.21274029,

-0.89546657, 0.3869025, -0.51080513, -1.1806322, -0.028182229,
0.42833188, 0.06651722, 0.3024719, -0.6343221, -0.36274117,
-0.67246044, -0.35955316, -0.8131463, -1.7262826, 0.17742614,
-0.40178093, -1.6301984, 0.46278226, -0.9072984, 0.051945396,
0.7290906, 0.12898292, 1.1394007, -1.2348258, 0.40234163})),
expectedOutput: tensor.New(
tensor.WithShape(1, 3, 1, 1),
tensor.WithBacking([]float64{0.47517386, -0.1940553, -0.28326008})),
},
} {
inputT := tst.inputT
expectedOutput := tst.expectedOutput
g := NewGraph()
assert := assert.New(t)
x := NodeFromAny(g, inputT)
output, err := GlobalAveragePool2D(x)

if err != nil {
t.Fatal(err)
}
m := NewTapeMachine(g)
defer m.Close() // register before RunAll so the machine is closed even when the test fails
if err := m.RunAll(); err != nil {
t.Fatalf("%+v", err)
}
if len(output.Shape()) != len(expectedOutput.Shape()) {
t.Fatalf("Bad output shape, expected %v, got %v", expectedOutput.Shape(), output.Shape())
}
for i, d := range output.Shape() {
if expectedOutput.Shape()[i] != d {
t.Fatalf("Bad output shape, expected %v, got %v", expectedOutput.Shape(), output.Shape())
}
}
assert.InDeltaSlice(expectedOutput.Data(), output.Value().Data(), 1e-6, "the two tensors should be equal.")
}
}
118 changes: 118 additions & 0 deletions op_nn.go
@@ -14,13 +14,15 @@ import (
"gorgonia.org/vecf64"
)

// Sanity checks
var (
_ SDOp = im2colOp{}
_ Op = col2imOp{}
_ Op = &maxPoolOp{}
_ Op = &maxPoolDiffOp{}
_ Op = &BatchNormOp{}
_ Op = &batchnormDiffOp{}
_ Op = &globalAveragePoolOp{}
)

/*
@@ -1625,3 +1627,119 @@ func (op *batchnormDiffOp) f32s(input, inGrad, outGrad *tensor.Dense) (err error
return nil

}

type globalAveragePoolOp struct{}

func (g *globalAveragePoolOp) Arity() int {
return 1
}

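// Type returns the op's type signature: a 4-D tensor of some element type a
// maps to a 4-D tensor of the same element type.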
func (g *globalAveragePoolOp) Type() hm.Type {
a := hm.TypeVariable('a')
t := newTensorType(4, a)
return hm.NewFnType(t, t)
}

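// InferShape derives the output shape (B, C, 1, 1) from a 4-D input shape (B, C, H, W).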
func (g *globalAveragePoolOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
b, err := inputs[0].DimSize(0)
if err != nil {
return nil, err
}
c, err := inputs[0].DimSize(1)
if err != nil {
return nil, err
}
// verify that the input has 4 dimensions, without resorting to full type inference
if _, err := inputs[0].DimSize(2); err != nil {
return nil, err
}
if _, err := inputs[0].DimSize(3); err != nil {
return nil, err
}
return tensor.Shape{b, c, 1, 1}, nil
}

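// Do performs the pooling: for every (batch, channel) pair it averages the H×W
// values of that feature map and writes the mean into a (B, C, 1, 1) output tensor.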
func (g *globalAveragePoolOp) Do(inputs ...Value) (Value, error) {
im := inputs[0]
switch v := im.(type) {
case tensor.Tensor:
B, C, H, W := v.Shape()[0], v.Shape()[1], v.Shape()[2], v.Shape()[3]
s, err := g.InferShape(v.Shape())
if err != nil {
return nil, err
}
output := tensor.New(tensor.Of(v.Dtype()), tensor.WithShape(s...))
switch v.Dtype() {
case tensor.Float64:
for b := 0; b < B; b++ {
for c := 0; c < C; c++ {
var sum float64
for h := 0; h < H; h++ {
for w := 0; w < W; w++ {
val, err := v.At(b, c, h, w)
if err != nil {
return nil, err
}
sum += val.(float64)
}
}
err := output.SetAt(sum/float64(H*W), b, c, 0, 0)
if err != nil {
return nil, err
}
}
}
case tensor.Float32:
for b := 0; b < B; b++ {
for c := 0; c < C; c++ {
var sum float32
for h := 0; h < H; h++ {
for w := 0; w < W; w++ {
val, err := v.At(b, c, h, w)
if err != nil {
return nil, err
}
sum += val.(float32)
}
}
err := output.SetAt(sum/float32(H*W), b, c, 0, 0)
if err != nil {
return nil, err
}
}
}
default:
return nil, nyi("Global Average Pool", v.Dtype())
}

return output, nil

default:
return nil, nyi("globalAveragePoolOp", inputs)
}
}
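
As a design note, the explicit per-dtype loops keep the arithmetic obvious, but the same reduction could be phrased with the tensor package's axis-wise helpers. A hedged, untested sketch of that alternative for the float64 case (assuming gorgonia.org/tensor's Sum and Div; a float32 input would need a float32 divisor):

// Sum over the H and W axes (2 and 3); the result has shape (B, C).
summed, err := tensor.Sum(v, 2, 3)
if err != nil {
	return nil, err
}
// Divide by the pool size, then restore the (B, C, 1, 1) shape.
out, err := tensor.Div(summed, float64(H*W))
if err != nil {
	return nil, err
}
err = out.Reshape(B, C, 1, 1)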

func (g *globalAveragePoolOp) ReturnsPtr() bool {
return false
}

func (g *globalAveragePoolOp) CallsExtern() bool {
return false
}

func (g *globalAveragePoolOp) OverwritesInput() int {
return -1
}

func (g *globalAveragePoolOp) WriteHash(h hash.Hash) {
fmt.Fprintf(h, "GlobalAveragePool")
}

func (g *globalAveragePoolOp) Hashcode() uint32 {
return simpleHash(g)
}

func (g *globalAveragePoolOp) String() string {
return "GlobalAveragePool"
}
