Skip to content

Commit

Permalink
Merge pull request #436 from strigi-form/fix435
Browse files Browse the repository at this point in the history
MaxPool2D checks (#435)
  • Loading branch information
chewxy committed Oct 1, 2020
2 parents 45cf447 + a5a4931 commit 7d6cbd9
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions nn.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,6 @@ func MaxPool2D(x *Node, kernel tensor.Shape, pad, stride []int) (*Node, error) {
xShape := x.Shape()
h, w := xShape[2], xShape[3]
kh, kw := kernel[0], kernel[1]
ph, pw := pad[0], pad[1]

// check shape
if xShape.Dims() != 4 {
Expand All @@ -344,12 +343,36 @@ func MaxPool2D(x *Node, kernel tensor.Shape, pad, stride []int) (*Node, error) {
return nil, errors.Errorf("Expected kernel to have a shape of dimension 2")
}

if h-kh == 0 && ph == 0 {
// checks
for _, s := range stride {
if s <= 0 {
return nil, errors.Errorf("Cannot use strides of less than or equal 0: %v", stride)
}
}

for _, p := range pad {
if p < 0 {
return nil, errors.Errorf("Cannot use padding of less than 0: %v", pad)
}
}

padNorth := pad[0]
padWest := pad[1]
padSouth := pad[0]
padEast := pad[1]
if len(pad) == 4 {
padNorth = pad[0]
padSouth = pad[1]
padWest = pad[2]
padEast = pad[3]
}

if h-kh+padNorth+padSouth < 0 {
// error
return nil, errors.New("Impossible height/kernel/pad combination")
}

if w-kw == 0 && pw == 0 {
if w-kw+padWest+padEast < 0 {
// error
return nil, errors.New("Impossible width/kernel/pad combination")
}
Expand Down

0 comments on commit 7d6cbd9

Please sign in to comment.