-
Notifications
You must be signed in to change notification settings - Fork 1
/
batchtoshape.go
50 lines (37 loc) · 1.36 KB
/
batchtoshape.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package reshape
import (
"github.com/dereklstinson/gocunets/devices/gpu/nvidia/cudnn"
"github.com/dereklstinson/gocunets/layers"
)
// getbatchtoshapeio allocates the output tensor for the batch-to-shape (B2S)
// op. The output format, data type and dimensions are obtained from
// l.op.GetB2SOutputProperties using x's volume together with the layer's
// stride and window, and the tensor is then created on handle.
//
// x must be non-nil (its Volume field is dereferenced). Any error from the
// property query or from tensor creation is returned unchanged.
//
// The input parameter is currently unused; it is kept so the signature stays
// compatible with existing callers.
func (l *Layer) getbatchtoshapeio(handle *cudnn.Handler, x *layers.Tensor, input bool) (*layers.Tensor, error) {
	yfrmt, ydtype, dims, err := l.op.GetB2SOutputProperties(x.Volume, l.stride, l.window)
	if err != nil {
		return nil, err
	}
	return layers.CreateTensor(handle, yfrmt, ydtype, dims)
}
// getbatchtoshapeioinference allocates the output tensor for the
// batch-to-shape (B2S) op on the inference path. Its body is currently
// identical to getbatchtoshapeio. Output format, data type and dimensions come
// from l.op.GetB2SOutputProperties using x's volume and the layer's stride and
// window, and the tensor is then created on handle.
//
// x must be non-nil (its Volume field is dereferenced). Errors are returned
// unchanged.
//
// The input parameter is currently unused; it is kept for signature
// compatibility with callers.
func (l *Layer) getbatchtoshapeioinference(handle *cudnn.Handler, x *layers.Tensor, input bool) (*layers.Tensor, error) {
	yfrmt, ydtype, dims, err := l.op.GetB2SOutputProperties(x.Volume, l.stride, l.window)
	if err != nil {
		return nil, err
	}
	return layers.CreateTensor(handle, yfrmt, ydtype, dims)
}
// batchtoshapeforwardprop runs the forward pass of the batch-to-shape (B2S)
// op by calling l.op.B2SForward with x's and y's volumes and the layer's
// stride. Presumably x is the source and y the destination; confirm this
// against the op's implementation.
//
// Both x and y must be non-nil (their Volume fields are dereferenced).
func (l *Layer) batchtoshapeforwardprop(handle *cudnn.Handler, x, y *layers.Tensor) error {
	if err := l.op.B2SForward(handle, x.Volume, y.Volume, l.stride); err != nil {
		return err
	}
	// Return a literal nil instead of forwarding err directly. If the op ever
	// returns a concrete (typed-nil) error value, it must not surface to the
	// caller as a non-nil error interface.
	return nil
}
// batchtoshapebackprop runs the backward pass of the batch-to-shape (B2S) op
// by calling l.op.B2SBackward with x's and y's volumes and the layer's stride.
// Which tensor receives the propagated values is determined by the op. This
// file does not show it; verify against the op before relying on it.
//
// Both x and y must be non-nil (their Volume fields are dereferenced).
func (l *Layer) batchtoshapebackprop(handle *cudnn.Handler, x, y *layers.Tensor) error {
	if err := l.op.B2SBackward(handle, x.Volume, y.Volume, l.stride); err != nil {
		return err
	}
	// Literal nil on success. This avoids leaking a typed-nil error through
	// the error interface if the op's return type is ever concrete.
	return nil
}