That's it... I think conv2d is done. Probably should write tests for it
fixed #13
fixed #14
fixed #8
chewxy committed Jun 2, 2017
1 parent d0be6b1 commit ee4c1a0
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions nn.go
@@ -127,3 +127,57 @@ func Rectify(x *Node) (retVal *Node, err error) {

return HadamardProd(x, retVal)
}

// Conv2d is a simple 2D convolution, to be used for CPU computation only. If CuDNN is used, use the CUDAConv2D function.
// These are the properties the inputs must fulfil:
//
// im: must have 4D shape. Expected format is BCHW (batch, channel, height, width)
// filter: must have 4D shape.
// kernelShape: shape of the filter kernel
// pad: len(pad) == 2
// stride: len(stride) == 2
func Conv2d(im, filter *Node, kernelShape tensor.Shape, stride, pad []int) (retVal *Node, err error) {
var colIm *Node
if colIm, err = Im2Col(im, kernelShape, pad, stride); err != nil {
return
}
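
// colIm holds the unrolled patches; its last dimension is expected to equal kernel*row*col of the filter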

layer := filter.Shape()[0]
kernel := filter.Shape()[1]
row := filter.Shape()[2]
col := filter.Shape()[3]

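// flatten the filter into a (layer, kernel*row*col) matrix: one row per output channel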
var flattened *Node
if flattened, err = Reshape(filter, tensor.Shape{layer, kernel * row * col}); err != nil {
return
}

// extract patch
batch := colIm.Shape()[0]
m := colIm.Shape()[1]
n := colIm.Shape()[2]
z := colIm.Shape()[3]

var patch, colImLayer *Node
if patch, err = Reshape(colIm, tensor.Shape{batch * m * n, z}); err != nil {
return
}

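// patch is (batch*m*n, z) and flattened is (layer, kernel*row*col); with transB set, the matrix multiplication yields a (batch*m*n, layer) result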
op := linAlgBinOp{
āBinaryOperator: matMulOperator,
transA: false,
transB: true,
}

if colImLayer, err = applyOp(op, patch, flattened); err != nil {
return
}

// now reshape and transpose the values back into the original order
var res *Node
if res, err = Reshape(colImLayer, tensor.Shape{batch, m, n, layer}); err != nil {
return
}

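// Transpose(res, 0, 3, 1, 2): (batch, m, n, layer) -> (batch, layer, m, n), i.e. back to BCHW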
return Transpose(res, 0, 3, 1, 2)
}
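
A minimal usage sketch (not part of this commit), assuming the package's usual graph and tensor constructors (NewGraph, NewTensor, WithShape, WithName) and purely illustrative shapes:

g := NewGraph()

// a batch of 32 single-channel 28x28 images, in BCHW order
im := NewTensor(g, tensor.Float64, 4, WithShape(32, 1, 28, 28), WithName("im"))

// 16 filters, each spanning 1 channel with a 3x3 kernel
filter := NewTensor(g, tensor.Float64, 4, WithShape(16, 1, 3, 3), WithName("filter"))

// 3x3 kernel, stride 1, no padding: the output node has shape (32, 16, 26, 26)
out, err := Conv2d(im, filter, tensor.Shape{3, 3}, []int{1, 1}, []int{0, 0})
if err != nil {
	// handle the error
}
_ = out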
